1from bisect import bisect_left
2from bisect import bisect_right
3from contextlib import contextmanager
4from copy import deepcopy
5from functools import wraps
6from inspect import isclass
7import calendar
8import collections
9import datetime
10import decimal
11import hashlib
12import itertools
13import logging
14import operator
15import re
16import socket
17import struct
18import sys
19import threading
20import time
21import uuid
22import warnings
23try:
24    from collections.abc import Mapping
25except ImportError:
26    from collections import Mapping
27
28try:
29    from pysqlite3 import dbapi2 as pysq3
30except ImportError:
31    try:
32        from pysqlite2 import dbapi2 as pysq3
33    except ImportError:
34        pysq3 = None
35try:
36    import sqlite3
37except ImportError:
38    sqlite3 = pysq3
39else:
40    if pysq3 and pysq3.sqlite_version_info >= sqlite3.sqlite_version_info:
41        sqlite3 = pysq3
42try:
43    from psycopg2cffi import compat
44    compat.register()
45except ImportError:
46    pass
47try:
48    import psycopg2
49    from psycopg2 import extensions as pg_extensions
50    try:
51        from psycopg2 import errors as pg_errors
52    except ImportError:
53        pg_errors = None
54except ImportError:
55    psycopg2 = pg_errors = None
56try:
57    from psycopg2.extras import register_uuid as pg_register_uuid
58    pg_register_uuid()
59except Exception:
60    pass
61
62mysql_passwd = False
63try:
64    import pymysql as mysql
65except ImportError:
66    try:
67        import MySQLdb as mysql
68        mysql_passwd = True
69    except ImportError:
70        mysql = None
71
72
# Version string of this module.
__version__ = '3.14.4'
# Public names exported by a star-import of this module.
__all__ = [
    'AsIs',
    'AutoField',
    'BareField',
    'BigAutoField',
    'BigBitField',
    'BigIntegerField',
    'BinaryUUIDField',
    'BitField',
    'BlobField',
    'BooleanField',
    'Case',
    'Cast',
    'CharField',
    'Check',
    'chunked',
    'Column',
    'CompositeKey',
    'Context',
    'Database',
    'DatabaseError',
    'DatabaseProxy',
    'DataError',
    'DateField',
    'DateTimeField',
    'DecimalField',
    'DeferredForeignKey',
    'DeferredThroughModel',
    'DJANGO_MAP',
    'DoesNotExist',
    'DoubleField',
    'DQ',
    'EXCLUDED',
    'Field',
    'FixedCharField',
    'FloatField',
    'fn',
    'ForeignKeyField',
    'IdentityField',
    'ImproperlyConfigured',
    'Index',
    'IntegerField',
    'IntegrityError',
    'InterfaceError',
    'InternalError',
    'IPField',
    'JOIN',
    'ManyToManyField',
    'Model',
    'ModelIndex',
    'MySQLDatabase',
    'NotSupportedError',
    'OP',
    'OperationalError',
    'PostgresqlDatabase',
    'PrimaryKeyField',  # XXX: Deprecated, change to AutoField.
    'prefetch',
    'ProgrammingError',
    'Proxy',
    'QualifiedNames',
    'SchemaManager',
    'SmallIntegerField',
    'Select',
    'SQL',
    'SqliteDatabase',
    'Table',
    'TextField',
    'TimeField',
    'TimestampField',
    'Tuple',
    'UUIDField',
    'Value',
    'ValuesList',
    'Window',
]
149
# logging.NullHandler exists on Python 2.7+; older interpreters get an
# equivalent handler that silently discards records.
try:  # Python 2.7+
    from logging import NullHandler
except ImportError:
    class NullHandler(logging.Handler):
        # Drop every record.
        def emit(self, record):
            pass

# Library logger named 'peewee'. A no-op handler is attached so that, when the
# application has not configured logging, records are not reported through
# logging's last-resort fallback.
logger = logging.getLogger('peewee')
logger.addHandler(NullHandler())
159
160
# Python 2/3 compatibility aliases. Both branches define text_type, bytes_type,
# buffer_type, izip_longest, callable_, multi_types, reraise and print_; the
# Python 3 branch additionally defines basestring, long and reduce, which are
# builtins on Python 2.
if sys.version_info[0] == 2:
    text_type = unicode
    bytes_type = str
    buffer_type = buffer
    izip_longest = itertools.izip_longest
    callable_ = callable
    # Container types treated as a collection of values.
    multi_types = (list, tuple, frozenset, set)
    # The three-argument raise statement is a SyntaxError on Python 3, so it
    # is compiled from a string and only on Python 2.
    exec('def reraise(tp, value, tb=None): raise tp, value, tb')
    def print_(s):
        sys.stdout.write(s)
        sys.stdout.write('\n')
else:
    import builtins
    try:
        from collections.abc import Callable
    except ImportError:
        from collections import Callable
    from functools import reduce
    callable_ = lambda c: isinstance(c, Callable)
    text_type = str
    bytes_type = bytes
    buffer_type = memoryview
    basestring = str
    long = int
    # Same as the Python 2 tuple plus range objects.
    multi_types = (list, tuple, frozenset, set, range)
    # Looked up via getattr, presumably so this line still parses on
    # Python 2, where `print` is a keyword.
    print_ = getattr(builtins, 'print')
    izip_longest = itertools.zip_longest
    # Re-raise `value`, attaching traceback `tb` when it differs from the
    # exception's current one. `tp` is accepted only for signature parity with
    # the Python 2 version.
    def reraise(tp, value, tb=None):
        if value.__traceback__ is not tb:
            raise value.with_traceback(tb)
        raise value
192
193
# When a sqlite driver is available, pass Decimal, date and time parameters to
# it as their str() form, and record the SQLite library version tuple.
# Without a driver the version is reported as (0, 0, 0).
if sqlite3:
    sqlite3.register_adapter(decimal.Decimal, str)
    sqlite3.register_adapter(datetime.date, str)
    sqlite3.register_adapter(datetime.time, str)
    __sqlite_version__ = sqlite3.sqlite_version_info
else:
    __sqlite_version__ = (0, 0, 0)
201
202
# Components that _sqlite_date_part() can extract from a date/time value.
__date_parts__ = set(['year', 'month', 'day', 'hour', 'minute', 'second'])

# Sqlite does not support the `date_part` SQL function, so we will define an
# implementation in python. These are the candidate formats handed to
# format_date_time() when parsing a stored date/time string.
__sqlite_datetime_formats__ = (
    '%Y-%m-%d %H:%M:%S',
    '%Y-%m-%d %H:%M:%S.%f',
    '%Y-%m-%d',
    '%H:%M:%S',
    '%H:%M:%S.%f',
    '%H:%M',
)

# strftime() patterns that truncate a timestamp to the given precision.
__sqlite_date_trunc__ = dict(
    year='%Y-01-01 00:00:00',
    month='%Y-%m-01 00:00:00',
    day='%Y-%m-%d 00:00:00',
    hour='%Y-%m-%d %H:00:00',
    minute='%Y-%m-%d %H:%M:00',
    second='%Y-%m-%d %H:%M:%S',
)

# MySQL variant: same patterns, except that MySQL's DATE_FORMAT() spells
# minutes as %i.
__mysql_date_trunc__ = dict(
    __sqlite_date_trunc__,
    minute='%Y-%m-%d %H:%i:00',
    second='%Y-%m-%d %H:%i:%S',
)
226
def _sqlite_date_part(lookup_type, datetime_string):
    """Python implementation of date_part() for SQLite.

    Returns the ``lookup_type`` attribute (e.g. ``year``) of the parsed
    ``datetime_string``, or None when the string is empty/None.
    """
    assert lookup_type in __date_parts__
    if datetime_string:
        parsed = format_date_time(datetime_string, __sqlite_datetime_formats__)
        return getattr(parsed, lookup_type)
233
def _sqlite_date_trunc(lookup_type, datetime_string):
    """Python implementation of date_trunc() for SQLite.

    Formats the parsed ``datetime_string`` with the truncation pattern for
    ``lookup_type``; returns None when the string is empty/None.
    """
    assert lookup_type in __sqlite_date_trunc__
    if datetime_string:
        parsed = format_date_time(datetime_string, __sqlite_datetime_formats__)
        pattern = __sqlite_date_trunc__[lookup_type]
        return parsed.strftime(pattern)
240
241
def __deprecated__(s, stacklevel=3):
    """Emit a DeprecationWarning with message ``s``.

    ``stacklevel`` defaults to 3 so the warning is attributed to the code
    that called the deprecated API (this helper -> deprecated API -> caller)
    rather than to this helper's own line. This assumes the deprecated API
    calls this helper directly; pass a different value if it does not.
    """
    warnings.warn(s, DeprecationWarning, stacklevel=stacklevel)
244
245
class attrdict(dict):
    """Dictionary whose keys can also be read, set and deleted as attributes.

    ``d.key`` reads ``d['key']``; a missing key raises AttributeError (not
    KeyError) so ``getattr(d, name, default)`` and ``hasattr`` behave
    normally. ``d += other`` merges in place; ``d + other`` returns a new
    merged attrdict.
    """
    def __getattr__(self, attr):
        try:
            return self[attr]
        except KeyError:
            raise AttributeError(attr)

    def __setattr__(self, attr, value):
        self[attr] = value

    def __delattr__(self, attr):
        # Mirror __setattr__: attributes live in the dict, so deleting one
        # removes the key (object.__delattr__ would never find it).
        try:
            del self[attr]
        except KeyError:
            raise AttributeError(attr)

    def __iadd__(self, rhs):
        self.update(rhs)
        return self

    def __add__(self, rhs):
        d = attrdict(self)
        d.update(rhs)
        return d
255
# Unique placeholder object, distinguishable from None by identity (`is`).
# Not referenced within this chunk.
SENTINEL = object()

#: Operations for use in SQL expressions.
#: Each value is the SQL operator text for that operation (e.g. NE is '!=').
OP = attrdict(
    AND='AND',
    OR='OR',
    ADD='+',
    SUB='-',
    MUL='*',
    DIV='/',
    BIN_AND='&',
    BIN_OR='|',
    XOR='#',
    MOD='%',
    EQ='=',
    LT='<',
    LTE='<=',
    GT='>',
    GTE='>=',
    NE='!=',
    IN='IN',
    NOT_IN='NOT IN',
    IS='IS',
    IS_NOT='IS NOT',
    LIKE='LIKE',
    ILIKE='ILIKE',
    BETWEEN='BETWEEN',
    REGEXP='REGEXP',
    IREGEXP='IREGEXP',
    CONCAT='||',
    BITWISE_NEGATION='~')
287
# To support "django-style" double-underscore filters, map each lookup name to
# a callable(lhs, rhs) that builds the filter expression, e.g. "__eq" ->
# operator.eq and "__like" -> Expression(lhs, OP.LIKE, rhs).
# NOTE(review): the operator-module entries presumably rely on the column
# types overloading those operators (e.g. `<<` producing IN); confirm there.
DJANGO_MAP = attrdict({
    'eq': operator.eq,
    'lt': operator.lt,
    'lte': operator.le,
    'gt': operator.gt,
    'gte': operator.ge,
    'ne': operator.ne,
    'in': operator.lshift,
    'is': lambda l, r: Expression(l, OP.IS, r),
    'like': lambda l, r: Expression(l, OP.LIKE, r),
    'ilike': lambda l, r: Expression(l, OP.ILIKE, r),
    'regexp': lambda l, r: Expression(l, OP.REGEXP, r),
})
303
#: Mapping of field type to the data-type supported by the database. Databases
#: may override or add to this list.
#: These are the generic defaults: e.g. BOOL is stored as SMALLINT, UUID as
#: TEXT (UUIDB as BLOB), and DEFAULT is an empty type name.
FIELD = attrdict(
    AUTO='INTEGER',
    BIGAUTO='BIGINT',
    BIGINT='BIGINT',
    BLOB='BLOB',
    BOOL='SMALLINT',
    CHAR='CHAR',
    DATE='DATE',
    DATETIME='DATETIME',
    DECIMAL='DECIMAL',
    DEFAULT='',
    DOUBLE='REAL',
    FLOAT='REAL',
    INT='INTEGER',
    SMALLINT='SMALLINT',
    TEXT='TEXT',
    TIME='TIME',
    UUID='TEXT',
    UUIDB='BLOB',
    VARCHAR='VARCHAR')
326
#: Join helpers (for convenience) -- all join types are supported, this object
#: is just to help avoid introducing errors by using strings everywhere.
#: Join.__sql__ emits the chosen value verbatim between the two sources.
JOIN = attrdict(
    INNER='INNER JOIN',
    LEFT_OUTER='LEFT OUTER JOIN',
    RIGHT_OUTER='RIGHT OUTER JOIN',
    FULL='FULL JOIN',
    FULL_OUTER='FULL OUTER JOIN',
    CROSS='CROSS JOIN',
    NATURAL='NATURAL JOIN',
    LATERAL='LATERAL',
    LEFT_LATERAL='LEFT JOIN LATERAL')

# Row representations.
# NOTE(review): presumably selects how result rows are returned (tuple, dict,
# namedtuple, constructor call, model instance); not used in this chunk.
ROW = attrdict(
    TUPLE=1,
    DICT=2,
    NAMED_TUPLE=3,
    CONSTRUCTOR=4,
    MODEL=5)
347
# Rendering scopes held in Context.state.scope. Nodes branch on them, e.g.
# Table renders "name AS alias" in SCOPE_SOURCE, only its alias in
# SCOPE_NORMAL, and the bare quoted name in SCOPE_VALUES; a CTE renders its
# definition only in SCOPE_CTE. Within this chunk they are compared with ==.
SCOPE_NORMAL = 1
SCOPE_SOURCE = 2
SCOPE_VALUES = 4
SCOPE_CTE = 8
SCOPE_COLUMN = 16

# Rules for parentheses around subqueries in compound select.
# Not referenced within this chunk.
CSQ_PARENTHESES_NEVER = 0
CSQ_PARENTHESES_ALWAYS = 1
CSQ_PARENTHESES_UNNESTED = 2
358
# Regular expressions used to convert class names to snake-case table names.
# First regex handles acronym followed by word or initial lower-word followed
# by a capitalized word. e.g. APIResponse -> API_Response / fooBar -> foo_Bar.
# Second regex catches any lowercase letter or digit still directly followed
# by a capital: words the first pass skipped because its matches cannot
# overlap (FooBarBaz -> Foo_BarBaz -> Foo_Bar_Baz) and trailing acronyms
# (userID -> user_ID). In both, the `_*` absorbs existing underscores so
# Foo_Bar does not become Foo__Bar.
SNAKE_CASE_STEP1 = re.compile('(.)_*([A-Z][a-z]+)')
SNAKE_CASE_STEP2 = re.compile('([a-z0-9])_*([A-Z])')
365
# Helper functions that are used in various parts of the codebase.
MODEL_BASE = '_metaclass_helper_'


def with_metaclass(meta, base=object):
    """Create an intermediate class named MODEL_BASE built by ``meta``.

    Subclassing the result gives a class whose metaclass is ``meta`` without
    version-specific metaclass syntax.
    """
    bases = (base,)
    namespace = {}
    return meta(MODEL_BASE, bases, namespace)
371
def merge_dict(source, overrides):
    """Return a shallow copy of ``source`` updated with ``overrides``.

    ``overrides`` may be None or empty, in which case only the copy is made.
    ``source`` itself is never modified.
    """
    result = source.copy()
    result.update(overrides or {})
    return result
377
def quote(path, quote_chars):
    """Quote each identifier in ``path`` and join them with dots.

    ``quote_chars`` is a two-item sequence of opening/closing characters
    (e.g. '""' or '[]'); each part is placed between them via str.join.
    """
    if len(path) != 1:
        return '.'.join([part.join(quote_chars) for part in path])
    return path[0].join(quote_chars)
382
def is_model(o):
    """Return True if ``o`` is a class deriving from Model."""
    return isclass(o) and issubclass(o, Model)
384
def ensure_tuple(value):
    """Wrap a scalar in a 1-tuple; pass lists/tuples through; keep None."""
    if value is None:
        return None
    if isinstance(value, (list, tuple)):
        return value
    return (value,)
388
def ensure_entity(value):
    """Wrap a non-Node value in Entity; pass Nodes through; keep None."""
    if value is None:
        return None
    if isinstance(value, Node):
        return value
    return Entity(value)
392
def make_snake_case(s):
    """Convert a CamelCase name to snake_case (e.g. APIResponse -> api_response)."""
    words_split = SNAKE_CASE_STEP1.sub(r'\1_\2', s)
    fully_split = SNAKE_CASE_STEP2.sub(r'\1_\2', words_split)
    return fully_split.lower()
396
def chunked(it, n):
    """Yield successive lists of up to ``n`` items taken from iterable ``it``.

    The last list may be shorter than ``n``. A non-positive ``n`` yields
    nothing.
    """
    iterator = iter(it)
    if n < 1:
        return
    while True:
        # islice never compares items. The previous sentinel + list.index()
        # approach compared with ==, which emptied the final chunk whenever
        # an item's __eq__ returns something truthy (e.g. SQL expression
        # nodes).
        group = list(itertools.islice(iterator, n))
        if not group:
            return
        yield group
404
405
class _callable_context_manager(object):
    """Mixin that lets a context manager double as a function decorator.

    ``@manager`` wraps the function so every call runs inside
    ``with manager:``.
    """
    def __call__(self, fn):
        manager = self

        @wraps(fn)
        def inner(*args, **kwargs):
            with manager:
                return fn(*args, **kwargs)
        return inner
413
414
class Proxy(object):
    """
    Create a proxy or placeholder for another object.
    """
    __slots__ = ('obj', '_callbacks')

    def __init__(self):
        self._callbacks = []
        self.initialize(None)

    def initialize(self, obj):
        """Point the proxy at ``obj`` and call every attached callback."""
        self.obj = obj
        for callback in self._callbacks:
            callback(obj)

    def attach_callback(self, callback):
        """Register ``callback(obj)`` to run on each initialize().

        Returns the callback unchanged so this can be used as a decorator.
        """
        self._callbacks.append(callback)
        return callback

    def passthrough(method):
        # Class-body helper: build a method forwarding to self.obj.<method>.
        # Needed for dunder methods such as __enter__, which Python looks up
        # on the type and therefore never routes through __getattr__.
        def inner(self, *args, **kwargs):
            if self.obj is None:
                raise AttributeError('Cannot use uninitialized Proxy.')
            return getattr(self.obj, method)(*args, **kwargs)
        return inner

    # Allow proxy to be used as a context-manager.
    __enter__ = passthrough('__enter__')
    __exit__ = passthrough('__exit__')

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. If one of our own slots was
        # never assigned (instance created via __new__ without __init__, as
        # copy/pickle machinery does), reading self.obj would re-enter
        # __getattr__('obj') endlessly; the slot check short-circuits that.
        if attr in Proxy.__slots__ or self.obj is None:
            raise AttributeError('Cannot use uninitialized Proxy.')
        return getattr(self.obj, attr)

    def __setattr__(self, attr, value):
        if attr not in self.__slots__:
            raise AttributeError('Cannot set attribute on proxy.')
        return super(Proxy, self).__setattr__(attr, value)
454
455
class DatabaseProxy(Proxy):
    """
    Proxy implementation specifically for proxying `Database` objects.
    """
    # These helpers are defined on the proxy itself, rather than reached via
    # Proxy.__getattr__, so the returned context managers wrap the proxy and
    # not the underlying database object.
    # NOTE(review): presumably this lets them be created before initialize()
    # is called (e.g. as decorators at import time) -- confirm.
    def connection_context(self):
        return ConnectionContext(self)
    def atomic(self, *args, **kwargs):
        return _atomic(self, *args, **kwargs)
    def manual_commit(self):
        return _manual(self)
    def transaction(self, *args, **kwargs):
        return _transaction(self, *args, **kwargs)
    def savepoint(self):
        return _savepoint(self)
470
471
# Empty class with no behaviour of its own; nothing in this chunk references
# it. NOTE(review): presumably used elsewhere in the module as a base/marker
# for model attribute descriptors -- confirm.
class ModelDescriptor(object): pass
473
474
475# SQL Generation.
476
477
class AliasManager(object):
    """Hand out short source aliases ("t1", "t2", ...) per nesting depth.

    One alias dictionary is kept for each depth reached via push(). pop()
    only moves the current depth back: the popped dictionary is kept and
    reused, with its existing aliases, by the next push() to that depth.
    The counter is shared by all depths, so generated aliases are unique
    within one manager.
    """
    __slots__ = ('_counter', '_current_index', '_mapping')

    def __init__(self):
        self._counter = 0        # Last number used for a generated alias.
        self._current_index = 0  # Current depth (1-based once pushed).
        self._mapping = []       # One {source: alias} dict per depth.
        self.push()

    @property
    def mapping(self):
        """The alias dictionary for the current depth."""
        return self._mapping[self._current_index - 1]

    def add(self, source):
        """Return ``source``'s alias at this depth, generating one if needed."""
        if source in self.mapping:
            return self.mapping[source]
        self._counter += 1
        self[source] = 't%d' % self._counter
        return self.mapping[source]

    def get(self, source, any_depth=False):
        """Return the alias for ``source``.

        With ``any_depth``, search from the current depth outward first;
        otherwise (or when not found) fall back to add() at this depth.
        """
        if any_depth:
            for depth in range(self._current_index - 1, -1, -1):
                aliases = self._mapping[depth]
                if source in aliases:
                    return aliases[source]
        return self.add(source)

    def __getitem__(self, source):
        return self.get(source)

    def __setitem__(self, source, alias):
        self.mapping[source] = alias

    def push(self):
        """Enter a deeper nesting level, creating its dict on first use."""
        self._current_index += 1
        if len(self._mapping) < self._current_index:
            self._mapping.append({})

    def pop(self):
        """Return to the enclosing level; the outermost one cannot be popped."""
        if self._current_index == 1:
            raise ValueError('Cannot pop() from empty alias manager.')
        self._current_index -= 1
520
521
class State(collections.namedtuple('_State', ('scope', 'parentheses',
                                              'settings'))):
    """Immutable rendering state: scope, parentheses flag and a settings dict.

    Unknown attributes are read from ``settings`` and default to None.
    """
    def __new__(cls, scope=SCOPE_NORMAL, parentheses=False, **kwargs):
        return super(State, cls).__new__(cls, scope, parentheses, kwargs)

    def __call__(self, scope=None, parentheses=None, **kwargs):
        # Derive a new state: scope and settings carry over, parentheses
        # does not.
        if scope is None:
            scope = self.scope

        if not kwargs:
            merged = self.settings
        elif self.settings:
            # Copy before updating so this state's settings stay untouched.
            merged = self.settings.copy()
            merged.update(kwargs)
        else:
            merged = kwargs
        return State(scope, parentheses, **merged)

    def __getattr__(self, attr_name):
        return self.settings.get(attr_name)
543
544
def __scope_context__(scope):
    """Build a context-manager method that renders its body in ``scope``.

    ``with ctx.scope_xxx(**settings):`` calls ``ctx(scope=scope, **settings)``,
    enters it and yields the context itself.
    """
    def inner(self, **kwargs):
        with self(scope=scope, **kwargs):
            yield self
    return contextmanager(inner)
551
552
class Context(object):
    """Accumulate SQL text and bound parameters while rendering a query tree.

    Nodes render themselves via ``__sql__(ctx)``, appending SQL fragments
    with literal() and parameters with value(). Calling the context
    (``with ctx(scope=..., parentheses=True):``) pushes a derived State;
    leaving the ``with`` block restores the previous one.
    """
    __slots__ = ('stack', '_sql', '_values', 'alias_manager', 'state')

    def __init__(self, **settings):
        self.stack = []    # Saved states, restored by __exit__().
        self._sql = []     # SQL fragments, joined by query().
        self._values = []  # Bound parameter values, in order.
        self.alias_manager = AliasManager()
        self.state = State(**settings)

    def as_new(self):
        """Return a fresh, empty context that has the same settings."""
        return Context(**self.state.settings)

    def column_sort_key(self, item):
        # NOTE(review): item is presumably a (column, value) pair; the key is
        # delegated to the first element's get_sort_key().
        return item[0].get_sort_key(self)

    @property
    def scope(self):
        return self.state.scope

    @property
    def parentheses(self):
        return self.state.parentheses

    @property
    def subquery(self):
        # Settings that were never set read as None (State.__getattr__).
        return self.state.subquery

    def __call__(self, **overrides):
        """Push a state derived from the current one; use with ``with``.

        Scope and settings are inherited; ``parentheses`` is not.
        """
        # An unchanged scope is redundant: State.__call__ inherits it anyway.
        if overrides and overrides.get('scope') == self.scope:
            del overrides['scope']

        self.stack.append(self.state)
        self.state = self.state(**overrides)
        return self

    scope_normal = __scope_context__(SCOPE_NORMAL)
    scope_source = __scope_context__(SCOPE_SOURCE)
    scope_values = __scope_context__(SCOPE_VALUES)
    scope_cte = __scope_context__(SCOPE_CTE)
    scope_column = __scope_context__(SCOPE_COLUMN)

    def __enter__(self):
        if self.parentheses:
            self.literal('(')
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.parentheses:
            self.literal(')')
        self.state = self.stack.pop()

    @contextmanager
    def push_alias(self):
        """Render the body one alias depth deeper (AliasManager.push/pop)."""
        self.alias_manager.push()
        try:
            yield
        finally:
            # Restore the depth even if rendering the body raised; otherwise
            # the alias manager is left one level too deep.
            self.alias_manager.pop()

    def sql(self, obj):
        """Render ``obj``: Nodes/Contexts via __sql__, model classes via their
        table, and anything else as a bound Value."""
        if isinstance(obj, (Node, Context)):
            return obj.__sql__(self)
        elif is_model(obj):
            return obj._meta.table.__sql__(self)
        else:
            return self.sql(Value(obj))

    def literal(self, keyword):
        """Append raw SQL text."""
        self._sql.append(keyword)
        return self

    def value(self, value, converter=None, add_param=True):
        """Append a bound parameter (and its placeholder unless add_param is
        false).

        ``converter`` overrides the state's converter; pass False to skip
        conversion entirely.
        """
        if converter:
            value = converter(value)
        elif converter is None and self.state.converter:
            # Explicitly check for None so that "False" can be used to signify
            # that no conversion should be applied.
            value = self.state.converter(value)

        if isinstance(value, Node):
            with self(converter=None):
                return self.sql(value)
        elif is_model(value):
            # Under certain circumstances, we could end-up treating a model-
            # class itself as a value. This check ensures that we drop the
            # table alias into the query instead of trying to parameterize a
            # model (for instance, passing a model as a function argument).
            with self.scope_column():
                return self.sql(value)

        self._values.append(value)
        return self.literal(self.state.param or '?') if add_param else self

    def __sql__(self, ctx):
        # Splice this context's accumulated SQL and parameters into ctx.
        ctx._sql.extend(self._sql)
        ctx._values.extend(self._values)
        return ctx

    def parse(self, node):
        """Render ``node`` and return ``(sql, params)``."""
        return self.sql(node).query()

    def query(self):
        """Return the accumulated ``(sql_string, params_list)``."""
        return ''.join(self._sql), self._values
655
656
def query_to_string(query):
    """Render ``query`` as SQL text with its parameters interpolated.

    Uses the query's bound database context when it has one, otherwise a
    default Context.
    """
    # NOTE: this function is not exported by default as it might be misused --
    # and this misuse could lead to sql injection vulnerabilities. This
    # function is intended for debugging or logging purposes ONLY.
    db = getattr(query, '_database', None)
    if db is not None:
        ctx = db.get_sql_context()
    else:
        ctx = Context()

    sql, params = ctx.sql(query).query()
    if not params:
        return sql

    param = ctx.state.param or '?'
    if param == '?':
        # With qmark placeholders the SQL is not a %-format template, so any
        # literal "%" (LIKE patterns, strftime formats embedded via SQL(),
        # the modulo operator) must be escaped before "?" becomes "%s".
        sql = sql.replace('%', '%%').replace('?', '%s')

    return sql % tuple(map(_query_val_transform, params))
676
def _query_val_transform(v):
    """Render one bound parameter as SQL literal text for query_to_string()."""
    # Interpolate parameters.
    if isinstance(v, (text_type, datetime.datetime, datetime.date,
                      datetime.time)):
        return "'%s'" % v
    if isinstance(v, bytes_type):
        try:
            decoded = v.decode('utf8')
        except UnicodeDecodeError:
            decoded = v.decode('raw_unicode_escape')
        return "'%s'" % decoded
    if isinstance(v, int):
        # int() also renders booleans as 1 / 0.
        return '%s' % int(v)
    if v is None:
        return 'NULL'
    return str(v)
695
696
697# AST.
698
699
class Node(object):
    """Base class for the elements of the SQL expression tree."""
    _coerce = True

    def clone(self):
        """Return a shallow copy: a new instance with a copied __dict__."""
        obj = self.__class__.__new__(self.__class__)
        obj.__dict__ = self.__dict__.copy()
        return obj

    def __sql__(self, ctx):
        raise NotImplementedError

    @staticmethod
    def copy(method):
        """Decorator for "generative" methods.

        The wrapped method is applied to a clone, and the clone is returned;
        the original node is left untouched. functools.wraps keeps the
        method's name and docstring for introspection/documentation.
        """
        @wraps(method)
        def inner(self, *args, **kwargs):
            clone = self.clone()
            method(clone, *args, **kwargs)
            return clone
        return inner

    def coerce(self, _coerce=True):
        """Return a node whose ``_coerce`` flag equals ``_coerce``.

        Returns ``self`` when the flag already matches, otherwise a clone.
        """
        if _coerce != self._coerce:
            clone = self.clone()
            clone._coerce = _coerce
            return clone
        return self

    def is_alias(self):
        return False

    def unwrap(self):
        return self
731
732
class ColumnFactory(object):
    """Attribute-style column builder: ``factory.name`` -> Column(node, 'name')."""
    __slots__ = ('node',)

    def __init__(self, node):
        self.node = node

    def __getattr__(self, name):
        return Column(self.node, name)
741
742
class _DynamicColumn(object):
    """Descriptor: instance access returns a ColumnFactory for that instance;
    class access returns the descriptor itself."""
    __slots__ = ()

    def __get__(self, instance, instance_type=None):
        if instance is None:
            return self
        return ColumnFactory(instance)  # Implements __getattr__().
750
751
class _ExplicitColumn(object):
    """Descriptor for sources with a fixed column list: instance access raises
    AttributeError; class access returns the descriptor itself."""
    __slots__ = ()

    def __get__(self, instance, instance_type=None):
        if instance is None:
            return self
        raise AttributeError(
            '%s specifies columns explicitly, and does not support '
            'dynamic column lookups.' % instance)
761
762
class Source(Node):
    """A node that can appear as a query source (FROM / JOIN target).

    Columns may be referenced dynamically as ``source.c.name``.
    """
    c = _DynamicColumn()

    def __init__(self, alias=None):
        super(Source, self).__init__()
        self._alias = alias

    @Node.copy
    def alias(self, name):
        self._alias = name

    def select(self, *columns):
        # With no explicit columns, select everything ("*").
        return Select((self,), columns or (SQL('*'),))

    def join(self, dest, join_type=JOIN.INNER, on=None):
        return Join(self, dest, join_type, on)

    def left_outer_join(self, dest, on=None):
        return Join(self, dest, JOIN.LEFT_OUTER, on)

    def cte(self, name, recursive=False, columns=None, materialized=None):
        return CTE(name, self, recursive=recursive, columns=columns,
                   materialized=materialized)

    def get_sort_key(self, ctx):
        # The explicit alias when set, otherwise the one the context assigns.
        return (self._alias or ctx.alias_manager[self],)

    def apply_alias(self, ctx):
        # " AS <alias>" is only emitted while defining the source; an alias
        # is generated for it if none was given.
        if ctx.scope != SCOPE_SOURCE:
            return ctx
        if self._alias:
            ctx.alias_manager[self] = self._alias
        ctx.literal(' AS ').sql(Entity(ctx.alias_manager[self]))
        return ctx

    def apply_column(self, ctx):
        # Refer to the source by its alias (an explicit alias wins).
        if self._alias:
            ctx.alias_manager[self] = self._alias
        alias = ctx.alias_manager[self]
        return ctx.sql(Entity(alias))
807
808
class _HashableSource(object):
    # Mixin, listed before a Source-derived base (e.g.
    # `class Table(_HashableSource, BaseTable)`), that caches a hash and uses
    # it for equality. Subclasses customise identity by overriding
    # _get_hash(). The default reads self._path, which must already be set
    # when this __init__ runs; _alias is set by the super().__init__() call
    # below, before the hash is computed.
    def __init__(self, *args, **kwargs):
        super(_HashableSource, self).__init__(*args, **kwargs)
        # Computed once here and refreshed by alias().
        self._update_hash()

    @Node.copy
    def alias(self, name):
        # The alias is part of the hash, so recompute it on the clone.
        self._alias = name
        self._update_hash()

    def _update_hash(self):
        self._hash = self._get_hash()

    def _get_hash(self):
        return hash((self.__class__, self._path, self._alias))

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        # Two hashable sources are equal when their cached hashes match; any
        # other operand produces an SQL "=" Expression instead of a bool.
        # NOTE(review): equality is purely hash-based, so a hash collision
        # between different sources would compare equal.
        if isinstance(other, _HashableSource):
            return self._hash == other._hash
        return Expression(self, OP.EQ, other)

    def __ne__(self, other):
        if isinstance(other, _HashableSource):
            return self._hash != other._hash
        return Expression(self, OP.NE, other)

    # Class-body helper producing comparison methods that build Expressions.
    def _e(op):
        def inner(self, rhs):
            return Expression(self, op, rhs)
        return inner
    __lt__ = _e(OP.LT)
    __le__ = _e(OP.LTE)
    __gt__ = _e(OP.GT)
    __ge__ = _e(OP.GTE)
846
847
def __bind_database__(meth):
    """Decorate a query-building method so its result is bound to
    ``self._database`` whenever that is set (truthy)."""
    @wraps(meth)
    def inner(self, *args, **kwargs):
        result = meth(self, *args, **kwargs)
        if not self._database:
            return result
        return result.bind(self._database)
    return inner
856
857
def __join__(join_type=JOIN.INNER, inverted=False):
    """Build a binary-operator method that joins two sources.

    With ``inverted`` (used for reflected operators such as __radd__) the
    operands are swapped, so the left-hand operand stays on the left.
    """
    def method(self, other):
        lhs, rhs = (other, self) if inverted else (self, other)
        return Join(lhs, rhs, join_type=join_type)
    return method
864
865
class BaseTable(Source):
    # Operator shorthand for joins:
    #   a & b -> INNER, a + b -> LEFT OUTER, a - b -> RIGHT OUTER,
    #   a | b -> FULL OUTER, a * b -> CROSS.
    # The reflected (__r*__) variants swap operands so the left-hand operand
    # stays on the left of the JOIN.
    __and__ = __join__(JOIN.INNER)
    __add__ = __join__(JOIN.LEFT_OUTER)
    __sub__ = __join__(JOIN.RIGHT_OUTER)
    __or__ = __join__(JOIN.FULL_OUTER)
    __mul__ = __join__(JOIN.CROSS)
    __rand__ = __join__(JOIN.INNER, inverted=True)
    __radd__ = __join__(JOIN.LEFT_OUTER, inverted=True)
    __rsub__ = __join__(JOIN.RIGHT_OUTER, inverted=True)
    __ror__ = __join__(JOIN.FULL_OUTER, inverted=True)
    __rmul__ = __join__(JOIN.CROSS, inverted=True)
877
878
class _BoundTableContext(_callable_context_manager):
    """Temporarily bind a table (and its model, if any) to ``database``.

    Usable as ``with table.bind_ctx(db):`` or as a decorator. On exit, both
    the table and its model are re-bound to the table's previous database.
    """
    def __init__(self, table, database):
        self.table = table
        self.database = database

    def _bind_all(self, database):
        # Bind the table, then its model when it has one.
        self.table.bind(database)
        if self.table._model is not None:
            self.table._model.bind(database)

    def __enter__(self):
        self._orig_database = self.table._database
        self._bind_all(self.database)
        return self.table

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._bind_all(self._orig_database)
895
896
class Table(_HashableSource, BaseTable):
    """A (optionally schema-qualified) table usable as a query source.

    When ``columns`` is given, each name becomes a Column attribute on the
    instance and ``table.c`` is replaced by an _ExplicitColumn placeholder;
    otherwise columns are reached dynamically through ``table.c``.
    """
    def __init__(self, name, columns=None, primary_key=None, schema=None,
                 alias=None, _model=None, _database=None):
        self.__name__ = name
        self._columns = columns
        self._primary_key = primary_key
        self._schema = schema
        self._path = (schema, name) if schema else (name,)
        self._model = _model
        self._database = _database
        # The hash computed during base-class init reads _path and _model,
        # so they are assigned first.
        super(Table, self).__init__(alias=alias)

        # Allow tables to restrict what columns are available.
        if columns is not None:
            self.c = _ExplicitColumn()
            for column in columns:
                setattr(self, column, Column(self, column))

        if primary_key:
            col_src = self if self._columns else self.c
            self.primary_key = getattr(col_src, primary_key)
        else:
            self.primary_key = None

    def clone(self):
        # Rebuild from the constructor arguments (instead of copying __dict__)
        # so the clone gets its own Column instances bound to itself.
        return Table(
            self.__name__,
            columns=self._columns,
            primary_key=self._primary_key,
            schema=self._schema,
            alias=self._alias,
            _model=self._model,
            _database=self._database)

    def bind(self, database=None):
        """Set the database used by queries built from this table."""
        self._database = database
        return self

    def bind_ctx(self, database=None):
        """Context manager/decorator that binds ``database`` temporarily."""
        return _BoundTableContext(self, database)

    def _get_hash(self):
        return hash((self.__class__, self._path, self._alias, self._model))

    @__bind_database__
    def select(self, *columns):
        # Without explicit columns, select the declared ones (if any).
        if not columns and self._columns:
            columns = [Column(self, column) for column in self._columns]
        return Select((self,), columns)

    @__bind_database__
    def insert(self, insert=None, columns=None, **kwargs):
        """Build an INSERT query.

        Keyword arguments name columns of this table and are merged into
        ``insert``; the caller's dictionary is not modified.
        """
        if kwargs:
            insert = {} if insert is None else insert.copy()
            src = self if self._columns else self.c
            for key, value in kwargs.items():
                insert[getattr(src, key)] = value
        return Insert(self, insert=insert, columns=columns)

    @__bind_database__
    def replace(self, insert=None, columns=None, **kwargs):
        """Like insert(), with on_conflict('REPLACE').

        Keyword arguments are forwarded, so ``replace(name='x')`` behaves
        like ``insert(name='x')`` rather than silently dropping the values.
        """
        return (self
                .insert(insert=insert, columns=columns, **kwargs)
                .on_conflict('REPLACE'))

    @__bind_database__
    def update(self, update=None, **kwargs):
        """Build an UPDATE query.

        Keyword arguments name columns of this table and are merged into
        ``update``; the caller's dictionary is not modified.
        """
        if kwargs:
            update = {} if update is None else update.copy()
            src = self if self._columns else self.c
            for key, value in kwargs.items():
                update[getattr(src, key)] = value
        return Update(self, update=update)

    @__bind_database__
    def delete(self):
        return Delete(self)

    def __sql__(self, ctx):
        if ctx.scope == SCOPE_VALUES:
            # Return the quoted table name.
            return ctx.sql(Entity(*self._path))

        if self._alias:
            ctx.alias_manager[self] = self._alias

        if ctx.scope == SCOPE_SOURCE:
            # Define the table and its alias.
            return self.apply_alias(ctx.sql(Entity(*self._path)))
        else:
            # Refer to the table using the alias.
            return self.apply_column(ctx)
990
991
class Join(BaseTable):
    """``lhs <join_type> rhs [ON predicate]`` usable as a query source."""
    def __init__(self, lhs, rhs, join_type=JOIN.INNER, on=None, alias=None):
        super(Join, self).__init__(alias=alias)
        self.lhs, self.rhs = lhs, rhs
        self.join_type = join_type
        self._on = on

    def on(self, predicate):
        # Sets the predicate on this join in place (no copy) and returns it.
        self._on = predicate
        return self

    def __sql__(self, ctx):
        (ctx
         .sql(self.lhs)
         .literal(' %s ' % self.join_type)
         .sql(self.rhs))
        if self._on is None:
            return ctx
        ctx.literal(' ON ').sql(self._on)
        return ctx
1012
1013
class ValuesList(_HashableSource, BaseTable):
    # A literal VALUES list usable as a query source.
    def __init__(self, values, columns=None, alias=None):
        # values: sequence of rows, each rendered as an enclosed node list.
        # columns: optional column names emitted after the alias.
        self._values = values
        self._columns = columns
        super(ValuesList, self).__init__(alias=alias)

    def _get_hash(self):
        # Hashes the identity of the values object, not its contents.
        return hash((self.__class__, id(self._values), self._alias))

    @Node.copy
    def columns(self, *names):
        # Returns a copy with the given column names.
        self._columns = names

    def __sql__(self, ctx):
        if self._alias:
            ctx.alias_manager[self] = self._alias

        if ctx.scope == SCOPE_SOURCE or ctx.scope == SCOPE_NORMAL:
            # Flip the parentheses flag: wrap "VALUES ..." in parentheses
            # unless the current state already has parentheses set
            # (presumably to avoid doubling them).
            with ctx(parentheses=not ctx.parentheses):
                ctx = (ctx
                       .literal('VALUES ')
                       .sql(CommaNodeList([
                           EnclosedNodeList(row) for row in self._values])))

            if ctx.scope == SCOPE_SOURCE:
                # As a source: append " AS alias" and the optional column list.
                ctx.literal(' AS ').sql(Entity(ctx.alias_manager[self]))
                if self._columns:
                    entities = [Entity(c) for c in self._columns]
                    ctx.sql(EnclosedNodeList(entities))
        else:
            # Any other scope: refer to it by alias only.
            ctx.sql(Entity(ctx.alias_manager[self]))

        return ctx
1047
1048
class CTE(_HashableSource, Source):
    """
    A common table expression.

    Inside a WITH clause (``SCOPE_CTE``) it renders as
    ``name [(columns)] AS [[NOT] MATERIALIZED] (query)``. Everywhere else it
    renders as just ``name``.

    :param name: CTE name, also used as its alias.
    :param query: the wrapped query. Its ``_cte_list`` is reset to ``()``,
        presumably so the inner query does not emit its own WITH clause.
    :param recursive: stored as ``_recursive`` and carried over by
        ``union()`` / ``union_all()``.
    :param columns: optional column names; strings are wrapped in ``Entity``.
    :param materialized: ``True`` emits ``MATERIALIZED``, ``False`` emits
        ``NOT MATERIALIZED``, ``None`` emits no hint.
    """
    def __init__(self, name, query, recursive=False, columns=None,
                 materialized=None):
        self._alias = name
        self._query = query
        self._recursive = recursive
        self._materialized = materialized
        if columns is not None:
            columns = [Entity(c) if isinstance(c, basestring) else c
                       for c in columns]
        self._columns = columns
        query._cte_list = ()
        super(CTE, self).__init__(alias=name)

    def select_from(self, *columns):
        """
        Build ``SELECT <columns> FROM <cte>`` with this CTE attached.

        The query is bound to the wrapped query's database. When the wrapped
        query exposes a ``model``, rows are constructed with it.

        :raises ValueError: if no columns are given.
        """
        if not columns:
            raise ValueError('select_from() must specify one or more columns '
                             'from the CTE to select.')

        query = (Select((self,), columns)
                 .with_cte(self)
                 .bind(self._query._database))
        try:
            query = query.objects(self._query.model)
        except AttributeError:
            pass
        return query

    def _get_hash(self):
        return hash((self.__class__, self._alias, id(self._query)))

    def union_all(self, rhs):
        """Return a new CTE whose query is ``<this query> UNION ALL rhs``."""
        clone = self._query.clone()
        # Carry the MATERIALIZED hint over; it was previously lost here.
        return CTE(self._alias, clone + rhs, self._recursive, self._columns,
                   materialized=self._materialized)
    __add__ = union_all

    def union(self, rhs):
        """Return a new CTE whose query is ``<this query> UNION rhs``."""
        clone = self._query.clone()
        # Carry the MATERIALIZED hint over; it was previously lost here.
        return CTE(self._alias, clone | rhs, self._recursive, self._columns,
                   materialized=self._materialized)
    __or__ = union

    def __sql__(self, ctx):
        # Outside the WITH clause, a CTE is referenced by name only.
        if ctx.scope != SCOPE_CTE:
            return ctx.sql(Entity(self._alias))

        with ctx.push_alias():
            ctx.alias_manager[self] = self._alias
            ctx.sql(Entity(self._alias))

            if self._columns:
                ctx.literal(' ').sql(EnclosedNodeList(self._columns))
            ctx.literal(' AS ')

            if self._materialized:
                ctx.literal('MATERIALIZED ')
            elif self._materialized is False:
                ctx.literal('NOT MATERIALIZED ')

            with ctx.scope_normal(parentheses=True):
                ctx.sql(self._query)
        return ctx
1110
1111
class ColumnBase(Node):
    """
    Base class for column-like nodes.

    Provides operator overloads that turn Python expressions such as
    ``col == 1``, ``col + 2`` or ``col << [1, 2]`` into ``Expression`` nodes.
    Also provides helpers for aliasing, casting, ordering and pattern
    matching.

    Note: defining ``__eq__`` without ``__hash__`` makes instances of this
    class unhashable on Python 3 unless a subclass defines ``__hash__``.
    """
    _converter = None

    @Node.copy
    def converter(self, converter=None):
        self._converter = converter

    def alias(self, alias):
        """Wrap in ``Alias`` when ``alias`` is truthy; otherwise return self."""
        if alias:
            return Alias(self, alias)
        return self

    def unalias(self):
        return self

    def cast(self, as_type):
        return Cast(self, as_type)

    def asc(self, collation=None, nulls=None):
        return Asc(self, collation=collation, nulls=nulls)
    __pos__ = asc

    def desc(self, collation=None, nulls=None):
        return Desc(self, collation=collation, nulls=nulls)
    __neg__ = desc

    def __invert__(self):
        return Negated(self)

    def _e(op, inv=False):
        """
        Lightweight factory which returns a method that builds an Expression
        consisting of the left-hand and right-hand operands, using `op`.
        """
        def inner(self, rhs):
            if inv:
                return Expression(rhs, op, self)
            return Expression(self, op, rhs)
        return inner
    __and__ = _e(OP.AND)
    __or__ = _e(OP.OR)

    __add__ = _e(OP.ADD)
    __sub__ = _e(OP.SUB)
    __mul__ = _e(OP.MUL)
    __div__ = __truediv__ = _e(OP.DIV)
    __xor__ = _e(OP.XOR)
    __radd__ = _e(OP.ADD, inv=True)
    __rsub__ = _e(OP.SUB, inv=True)
    __rmul__ = _e(OP.MUL, inv=True)
    __rdiv__ = __rtruediv__ = _e(OP.DIV, inv=True)
    __rand__ = _e(OP.AND, inv=True)
    __ror__ = _e(OP.OR, inv=True)
    __rxor__ = _e(OP.XOR, inv=True)

    # Comparing against None produces IS / IS NOT rather than = / !=.
    def __eq__(self, rhs):
        op = OP.IS if rhs is None else OP.EQ
        return Expression(self, op, rhs)
    def __ne__(self, rhs):
        op = OP.IS_NOT if rhs is None else OP.NE
        return Expression(self, op, rhs)

    __lt__ = _e(OP.LT)
    __le__ = _e(OP.LTE)
    __gt__ = _e(OP.GT)
    __ge__ = _e(OP.GTE)
    __lshift__ = _e(OP.IN)
    __rshift__ = _e(OP.IS)
    __mod__ = _e(OP.LIKE)
    __pow__ = _e(OP.ILIKE)

    bin_and = _e(OP.BIN_AND)
    bin_or = _e(OP.BIN_OR)
    in_ = _e(OP.IN)
    not_in = _e(OP.NOT_IN)
    # ``regexp`` is defined as a regular method below. The former
    # ``regexp = _e(OP.REGEXP)`` assignment here was always overridden.

    # Special expressions.
    def is_null(self, is_null=True):
        op = OP.IS if is_null else OP.IS_NOT
        return Expression(self, op, None)

    def _escape_like_expr(self, s, template):
        """
        Insert ``s`` into a LIKE ``template``, escaping ``\\``, ``_`` and
        ``%``. When escaping was needed, an ``ESCAPE '\\'`` clause is added.
        """
        if s.find('_') >= 0 or s.find('%') >= 0 or s.find('\\') >= 0:
            s = s.replace('\\', '\\\\').replace('_', '\\_').replace('%', '\\%')
            return NodeList((template % s, SQL('ESCAPE'), '\\'))
        return template % s
    def contains(self, rhs):
        """Case-insensitive substring match (ILIKE '%rhs%')."""
        if isinstance(rhs, Node):
            rhs = Expression('%', OP.CONCAT,
                             Expression(rhs, OP.CONCAT, '%'))
        else:
            rhs = self._escape_like_expr(rhs, '%%%s%%')
        return Expression(self, OP.ILIKE, rhs)
    def startswith(self, rhs):
        """Case-insensitive prefix match (ILIKE 'rhs%')."""
        if isinstance(rhs, Node):
            rhs = Expression(rhs, OP.CONCAT, '%')
        else:
            rhs = self._escape_like_expr(rhs, '%s%%')
        return Expression(self, OP.ILIKE, rhs)
    def endswith(self, rhs):
        """Case-insensitive suffix match (ILIKE '%rhs')."""
        if isinstance(rhs, Node):
            rhs = Expression('%', OP.CONCAT, rhs)
        else:
            rhs = self._escape_like_expr(rhs, '%%%s')
        return Expression(self, OP.ILIKE, rhs)
    def between(self, lo, hi):
        return Expression(self, OP.BETWEEN, NodeList((lo, SQL('AND'), hi)))
    def concat(self, rhs):
        return StringExpression(self, OP.CONCAT, rhs)
    def regexp(self, rhs):
        return Expression(self, OP.REGEXP, rhs)
    def iregexp(self, rhs):
        return Expression(self, OP.IREGEXP, rhs)
    def __getitem__(self, item):
        """``col[lo:hi]`` builds BETWEEN; ``col[x]`` builds equality."""
        if isinstance(item, slice):
            if item.start is None or item.stop is None:
                raise ValueError('BETWEEN range must have both a start- and '
                                 'end-point.')
            return self.between(item.start, item.stop)
        return self == item

    def distinct(self):
        return NodeList((SQL('DISTINCT'), self))

    def collate(self, collation):
        return NodeList((self, SQL('COLLATE %s' % collation)))

    def get_sort_key(self, ctx):
        return ()
1242
1243
class Column(ColumnBase):
    """
    A named column of ``source``.

    Renders as ``<source>.<name>``, or just ``<name>`` in ``SCOPE_VALUES``.
    """
    def __init__(self, source, name):
        self.source = source
        self.name = name

    def get_sort_key(self, ctx):
        own_key = (self.name,)
        if ctx.scope != SCOPE_VALUES:
            return self.source.get_sort_key(ctx) + own_key
        return own_key

    def __hash__(self):
        return hash((self.source, self.name))

    def __sql__(self, ctx):
        if ctx.scope != SCOPE_VALUES:
            with ctx.scope_column():
                return ctx.sql(self.source).literal('.').sql(Entity(self.name))
        return ctx.sql(Entity(self.name))
1264
1265
class WrappedNode(ColumnBase):
    """
    Base for nodes that decorate another node.

    Inherits the wrapped node's ``_coerce`` flag (default ``True``) and
    ``_converter`` (default ``None``). Delegates ``is_alias()`` and
    ``unwrap()`` to the wrapped node.
    """
    def __init__(self, node):
        inherited_coerce = getattr(node, '_coerce', True)
        inherited_converter = getattr(node, '_converter', None)
        self.node = node
        self._coerce = inherited_coerce
        self._converter = inherited_converter

    def is_alias(self):
        wrapped = self.node
        return wrapped.is_alias()

    def unwrap(self):
        wrapped = self.node
        return wrapped.unwrap()
1277
1278
class EntityFactory(object):
    """Build ``Entity(node, attr)`` for any attribute accessed on it."""
    __slots__ = ('node',)

    def __init__(self, node):
        self.node = node

    def __getattr__(self, name):
        return Entity(self.node, name)
1285
1286
class _DynamicEntity(object):
    """
    Descriptor returning an ``EntityFactory`` bound to the instance's
    ``_alias``; accessed on the class, it returns the descriptor itself.
    """
    __slots__ = ()

    def __get__(self, instance, instance_type=None):
        if instance is None:
            return self
        # EntityFactory implements __getattr__() to build entities.
        return EntityFactory(instance._alias)
1293
1294
class Alias(WrappedNode):
    """
    A node with an alias.

    As a source it renders as ``<node> AS <alias>``; elsewhere it renders as
    just ``<alias>``. ``alias_obj.c.col`` builds ``Entity(alias, 'col')``.
    """
    c = _DynamicEntity()

    def __init__(self, node, alias):
        super(Alias, self).__init__(node)
        self._alias = alias

    def __hash__(self):
        return hash(self._alias)

    def alias(self, alias=None):
        # With no alias, strip aliasing and hand back the bare node.
        return self.node if alias is None else Alias(self.node, alias)

    def unalias(self):
        return self.node

    def is_alias(self):
        return True

    def __sql__(self, ctx):
        if ctx.scope != SCOPE_SOURCE:
            return ctx.sql(Entity(self._alias))
        rendered = ctx.sql(self.node).literal(' AS ')
        return rendered.sql(Entity(self._alias))
1325
1326
class Negated(WrappedNode):
    """Logical negation, rendered as ``NOT <node>``; ``~`` unwraps it again."""
    def __invert__(self):
        return self.node

    def __sql__(self, ctx):
        prefixed = ctx.literal('NOT ')
        return prefixed.sql(self.node)
1333
1334
class BitwiseMixin(object):
    """
    Mixin that maps Python operators to bitwise SQL operations.

    ``&`` -> ``bin_and``, ``|`` -> ``bin_or``, ``a - b`` -> ``a & <negated b>``,
    and ``~`` -> ``BitwiseNegated``.
    """
    def __and__(self, other):
        return self.bin_and(other)

    def __or__(self, other):
        return self.bin_or(other)

    def __sub__(self, other):
        # NOTE(review): ``bin_negated`` is not defined on ColumnBase or
        # anywhere in this chunk; confirm that the operands used here
        # provide it.
        return self.bin_and(other.bin_negated())

    def __invert__(self):
        return BitwiseNegated(self)
1347
1348
class BitwiseNegated(BitwiseMixin, WrappedNode):
    """
    Bitwise complement of the wrapped node, rendered as ``~<node>``.

    Inverting again (``~``) returns the original node.
    """
    # Operator emitted before the node. ``__sql__`` reads ``self.op``, but
    # neither this class nor its visible bases define it, so rendering raised
    # AttributeError. ``~`` is the unary bitwise-NOT operator in SQLite,
    # PostgreSQL and MySQL. Backends can still remap it through
    # ``ctx.state.operations``.
    op = '~'

    def __invert__(self):
        return self.node

    def __sql__(self, ctx):
        if ctx.state.operations:
            op_sql = ctx.state.operations.get(self.op, self.op)
        else:
            op_sql = self.op
        return ctx.literal(op_sql).sql(self.node)
1359
1360
class Value(ColumnBase):
    """
    A literal value rendered through ``ctx.value()`` with ``converter``.

    When ``unpack`` is true and ``value`` is one of ``multi_types``, each item
    becomes its own node and the whole renders as a parenthesized list.
    ``Node`` items are kept as-is; other items are wrapped in ``Value`` with
    the same converter.
    """
    def __init__(self, value, converter=None, unpack=True):
        self.value = value
        self.converter = converter
        self.multi = unpack and isinstance(self.value, multi_types)
        if self.multi:
            self.values = [
                item if isinstance(item, Node) else Value(item, self.converter)
                for item in self.value]

    def __sql__(self, ctx):
        if not self.multi:
            return ctx.value(self.value, self.converter)
        # For multi-part values (e.g. lists of IDs).
        return ctx.sql(EnclosedNodeList(self.values))
1380
1381
def AsIs(value):
    """Wrap ``value`` as a single ``Value`` that is never unpacked."""
    return Value(value, converter=None, unpack=False)
1384
1385
class Cast(WrappedNode):
    """Render ``CAST(<node> AS <cast>)``; value coercion is disabled."""
    def __init__(self, node, cast):
        super(Cast, self).__init__(node)
        self._cast = cast
        self._coerce = False

    def __sql__(self, ctx):
        ctx = ctx.literal('CAST(')
        ctx = ctx.sql(self.node)
        return ctx.literal(' AS %s)' % self._cast)
1397
1398
class Ordering(WrappedNode):
    """
    An ORDER BY term: ``<node> <direction> [COLLATE c] [NULLS FIRST|LAST]``.

    When ``ctx.state.nulls_ordering`` is falsy (presumably a backend without
    native ``NULLS FIRST/LAST``), the null ordering is emulated by a
    preceding ``CASE WHEN <node> IS NULL ...`` sort term.

    :raises ValueError: if ``nulls`` is not ``"first"`` or ``"last"``
        (case-insensitive).
    """
    def __init__(self, node, direction, collation=None, nulls=None):
        super(Ordering, self).__init__(node)
        self.direction = direction
        self.collation = collation
        self.nulls = nulls
        if nulls and nulls.lower() not in ('first', 'last'):
            raise ValueError('Ordering nulls= parameter must be "first" or '
                             '"last", got: %s' % nulls)

    def collate(self, collation=None):
        """
        Return a new ordering using ``collation``. The direction and the
        ``nulls`` setting are preserved.
        """
        # Pass ``nulls`` through; dropping it lost NULLS FIRST/LAST.
        return Ordering(self.node, self.direction, collation, self.nulls)

    def _null_ordering_case(self, nulls):
        """Build the CASE sort key that emulates NULLS FIRST/LAST."""
        if nulls.lower() == 'last':
            ifnull, notnull = 1, 0
        elif nulls.lower() == 'first':
            ifnull, notnull = 0, 1
        else:
            raise ValueError('unsupported value for nulls= ordering.')
        return Case(None, ((self.node.is_null(), ifnull),), notnull)

    def __sql__(self, ctx):
        if self.nulls and not ctx.state.nulls_ordering:
            ctx.sql(self._null_ordering_case(self.nulls)).literal(', ')

        ctx.sql(self.node).literal(' %s' % self.direction)
        if self.collation:
            ctx.literal(' COLLATE %s' % self.collation)
        if self.nulls and ctx.state.nulls_ordering:
            ctx.literal(' NULLS %s' % self.nulls)
        return ctx
1431
1432
def Asc(node, collation=None, nulls=None):
    """Ascending ``Ordering`` of ``node``."""
    return Ordering(node, 'ASC', collation=collation, nulls=nulls)
1435
1436
def Desc(node, collation=None, nulls=None):
    """Descending ``Ordering`` of ``node``."""
    return Ordering(node, 'DESC', collation=collation, nulls=nulls)
1439
1440
class Expression(ColumnBase):
    """
    A binary expression, rendered as ``<lhs> <op> <rhs>``.

    :param lhs: left-hand operand.
    :param op: operator token; backends may remap it through
        ``ctx.state.operations``.
    :param rhs: right-hand operand.
    :param flat: when true, do not wrap the expression in parentheses.
    """
    def __init__(self, lhs, op, rhs, flat=False):
        self.lhs = lhs
        self.op = op
        self.rhs = rhs
        self.flat = flat

    def __sql__(self, ctx):
        overrides = {'parentheses': not self.flat, 'in_expr': True}

        # First attempt to unwrap the node on the left-hand-side, so that we
        # can get at the underlying Field if one is present.
        node = raw_node = self.lhs
        if isinstance(raw_node, WrappedNode):
            node = raw_node.unwrap()

        # Set up the appropriate converter if we have a field on the left side.
        # Wrappers such as Cast set _coerce = False to opt out of conversion.
        if isinstance(node, Field) and raw_node._coerce:
            overrides['converter'] = node.db_value
            overrides['is_fk_expr'] = isinstance(node, ForeignKeyField)
        else:
            overrides['converter'] = None

        # Let the backend substitute its own spelling of the operator.
        if ctx.state.operations:
            op_sql = ctx.state.operations.get(self.op, self.op)
        else:
            op_sql = self.op

        with ctx(**overrides):
            # Postgresql reports an error for IN/NOT IN (), so convert to
            # the equivalent boolean expression. The rhs is rendered in a
            # fresh context (as_new) to detect the empty list.
            op_in = self.op == OP.IN or self.op == OP.NOT_IN
            if op_in and ctx.as_new().parse(self.rhs)[0] == '()':
                return ctx.literal('0 = 1' if self.op == OP.IN else '1 = 1')

            return (ctx
                    .sql(self.lhs)
                    .literal(' %s ' % op_sql)
                    .sql(self.rhs))
1480
1481
class StringExpression(Expression):
    """
    An expression whose ``+`` means string concatenation (``OP.CONCAT``)
    rather than addition, on either side.
    """

    def __add__(self, rhs):
        return self.concat(rhs)

    def __radd__(self, lhs):
        return StringExpression(lhs, OP.CONCAT, self)
1487
1488
class Entity(ColumnBase):
    """
    A quoted, dotted identifier path (e.g. ``"table"."column"``).

    Falsy path parts are dropped and embedded ``"`` characters are doubled.
    Attribute access extends the path: ``Entity('t').col`` is equivalent to
    ``Entity('t', 'col')``.
    """
    def __init__(self, *path):
        parts = []
        for part in path:
            if part:
                parts.append(part.replace('"', '""'))
        self._path = parts

    def __getattr__(self, attr):
        return Entity(*(self._path + [attr]))

    def get_sort_key(self, ctx):
        return tuple(self._path)

    def __hash__(self):
        key = (self.__class__.__name__, tuple(self._path))
        return hash(key)

    def __sql__(self, ctx):
        quote_chars = ctx.state.quote or '""'
        return ctx.literal(quote(self._path, quote_chars))
1504
1505
class SQL(ColumnBase):
    """
    Raw SQL text.

    Each optional parameter is passed to ``ctx.value`` with converter
    ``False`` and ``add_param=False``.
    """
    def __init__(self, sql, params=None):
        self.sql = sql
        self.params = params

    def __sql__(self, ctx):
        ctx.literal(self.sql)
        for param in (self.params or ()):
            ctx.value(param, False, add_param=False)
        return ctx
1517
1518
def Check(constraint, name=None):
    """
    Build a ``CHECK (<constraint>)`` clause, prefixed with
    ``CONSTRAINT <name>`` when ``name`` is truthy.
    """
    check = SQL('CHECK (%s)' % constraint)
    if name:
        return NodeList((SQL('CONSTRAINT'), Entity(name), check))
    return check
1524
1525
class Function(ColumnBase):
    """
    A SQL function call, rendered as ``name(arg, ...)``.

    :param name: function name, emitted verbatim.
    :param arguments: argument sequence; non-``Node`` arguments are wrapped as
        ``Value(arg, False)`` when rendered.
    :param coerce: value-coercion flag; forced to ``False`` for ``sum``,
        ``count``, ``cast`` and ``array_agg`` (case-insensitive).
    :param python_value: optional callable stored as ``_python_value``.
    """
    def __init__(self, name, arguments, coerce=True, python_value=None):
        self.name = name
        self.arguments = arguments
        self._filter = None
        self._order_by = None
        self._python_value = python_value
        if name and name.lower() in ('sum', 'count', 'cast', 'array_agg'):
            self._coerce = False
        else:
            self._coerce = coerce

    # Any attribute not found normally becomes a function factory:
    # ``fn.COUNT(x)`` builds ``Function('COUNT', (x,))``. This also applies
    # to misspelled attribute names, so typos do not raise AttributeError.
    def __getattr__(self, attr):
        def decorator(*args, **kwargs):
            return Function(attr, args, **kwargs)
        return decorator

    @Node.copy
    def filter(self, where=None):
        """Add a ``FILTER (WHERE <where>)`` clause (Node.copy-decorated)."""
        self._filter = where

    @Node.copy
    def order_by(self, *ordering):
        """Make this an ordered aggregate: ``name(args ORDER BY ...)``."""
        self._order_by = ordering

    @Node.copy
    def python_value(self, func=None):
        self._python_value = func

    def over(self, partition_by=None, order_by=None, start=None, end=None,
             frame_type=None, window=None, exclude=None):
        """
        Build ``<function> OVER <window>``.

        A named ``Window`` (passed as ``window`` or as the first argument) is
        referenced by alias. Otherwise an inline window specification is
        built from the remaining arguments.
        """
        if isinstance(partition_by, Window) and window is None:
            window = partition_by

        if window is not None:
            node = WindowAlias(window)
        else:
            node = Window(partition_by=partition_by, order_by=order_by,
                          start=start, end=end, frame_type=frame_type,
                          exclude=exclude, _inline=True)
        return NodeList((self, SQL('OVER'), node))

    def __sql__(self, ctx):
        ctx.literal(self.name)
        if not len(self.arguments):
            ctx.literal('()')
        else:
            args = self.arguments

            # If this is an ordered aggregate, then we will modify the last
            # argument to append the ORDER BY ... clause. We do this to avoid
            # double-wrapping any expression args in parentheses, as NodeList
            # has a special check (hack) in place to work around this.
            if self._order_by:
                args = list(args)
                args[-1] = NodeList((args[-1], SQL('ORDER BY'),
                                     CommaNodeList(self._order_by)))

            with ctx(in_function=True, function_arg_count=len(self.arguments)):
                ctx.sql(EnclosedNodeList([
                    (arg if isinstance(arg, Node) else Value(arg, False))
                    for arg in args]))

        if self._filter:
            ctx.literal(' FILTER (WHERE ').sql(self._filter).literal(')')
        return ctx
1592
1593
# Factory for arbitrary SQL function calls. Attribute access goes through
# Function.__getattr__, so ``fn.LOWER(col)`` builds ``Function('LOWER', (col,))``.
fn = Function(None, None)
1595
1596
class Window(Node):
    """
    A window specification for ``OVER`` clauses.

    :param partition_by: expression(s) for ``PARTITION BY``.
    :param order_by: expression(s) for ``ORDER BY``.
    :param start: frame start; non-``SQL`` values are wrapped in ``SQL``.
    :param end: frame end; only valid together with ``start``.
    :param frame_type: ``ROWS``, ``RANGE`` or ``GROUPS``; defaults to ``ROWS``
        when a frame start is given.
    :param extends: base window (``Window`` or name) this window extends.
    :param exclude: frame exclusion (e.g. ``Window.CURRENT_ROW``).
    :param alias: window name, defaulting to ``'w'``.
    :param _inline: when true, render only the parenthesized specification,
        with no ``<alias> AS`` prefix (``Function.over()`` uses this).

    :raises ValueError: if ``end`` is given without ``start``.
    """
    # Frame start/end and frame exclusion.
    CURRENT_ROW = SQL('CURRENT ROW')
    GROUP = SQL('GROUP')
    TIES = SQL('TIES')
    NO_OTHERS = SQL('NO OTHERS')

    # Frame types.
    GROUPS = 'GROUPS'
    RANGE = 'RANGE'
    ROWS = 'ROWS'

    def __init__(self, partition_by=None, order_by=None, start=None, end=None,
                 frame_type=None, extends=None, exclude=None, alias=None,
                 _inline=False):
        super(Window, self).__init__()
        if start is not None and not isinstance(start, SQL):
            start = SQL(start)
        if end is not None and not isinstance(end, SQL):
            end = SQL(end)

        self.partition_by = ensure_tuple(partition_by)
        self.order_by = ensure_tuple(order_by)
        self.start = start
        self.end = end
        if self.start is None and self.end is not None:
            raise ValueError('Cannot specify WINDOW end without start.')
        self._alias = alias or 'w'
        self._inline = _inline
        self.frame_type = frame_type
        self._extends = extends
        self._exclude = exclude

    # Unlike the Node.copy-decorated setters below, this renames in place.
    def alias(self, alias=None):
        self._alias = alias or 'w'
        return self

    @Node.copy
    def as_range(self):
        self.frame_type = Window.RANGE

    @Node.copy
    def as_rows(self):
        self.frame_type = Window.ROWS

    @Node.copy
    def as_groups(self):
        self.frame_type = Window.GROUPS

    @Node.copy
    def extends(self, window=None):
        self._extends = window

    @Node.copy
    def exclude(self, frame_exclusion=None):
        # Plain strings are wrapped as raw SQL.
        if isinstance(frame_exclusion, basestring):
            frame_exclusion = SQL(frame_exclusion)
        self._exclude = frame_exclusion

    @staticmethod
    def following(value=None):
        """``UNBOUNDED FOLLOWING`` if value is None, else ``<n> FOLLOWING``."""
        if value is None:
            return SQL('UNBOUNDED FOLLOWING')
        return SQL('%d FOLLOWING' % value)

    @staticmethod
    def preceding(value=None):
        """``UNBOUNDED PRECEDING`` if value is None, else ``<n> PRECEDING``."""
        if value is None:
            return SQL('UNBOUNDED PRECEDING')
        return SQL('%d PRECEDING' % value)

    def __sql__(self, ctx):
        # Named (non-inline) windows outside source scope are emitted as
        # "<alias> AS (...)", presumably for the WINDOW clause.
        if ctx.scope != SCOPE_SOURCE and not self._inline:
            ctx.literal(self._alias)
            ctx.literal(' AS ')

        with ctx(parentheses=True):
            parts = []
            if self._extends is not None:
                ext = self._extends
                if isinstance(ext, Window):
                    ext = SQL(ext._alias)
                elif isinstance(ext, basestring):
                    ext = SQL(ext)
                parts.append(ext)
            if self.partition_by:
                parts.extend((
                    SQL('PARTITION BY'),
                    CommaNodeList(self.partition_by)))
            if self.order_by:
                parts.extend((
                    SQL('ORDER BY'),
                    CommaNodeList(self.order_by)))
            # Frame clause: "<type> BETWEEN start AND end", "<type> start",
            # or, with only a frame type, "<type> UNBOUNDED PRECEDING".
            if self.start is not None and self.end is not None:
                frame = self.frame_type or 'ROWS'
                parts.extend((
                    SQL('%s BETWEEN' % frame),
                    self.start,
                    SQL('AND'),
                    self.end))
            elif self.start is not None:
                parts.extend((SQL(self.frame_type or 'ROWS'), self.start))
            elif self.frame_type is not None:
                parts.append(SQL('%s UNBOUNDED PRECEDING' % self.frame_type))
            if self._exclude is not None:
                parts.extend((SQL('EXCLUDE'), self._exclude))
            ctx.sql(NodeList(parts))
        return ctx
1705
1706
class WindowAlias(Node):
    """Reference a named ``Window`` by its alias (falls back to ``'w'``)."""
    def __init__(self, window):
        self.window = window

    def alias(self, window_alias):
        # Renames the shared Window object in place.
        self.window._alias = window_alias
        return self

    def __sql__(self, ctx):
        window_name = self.window._alias or 'w'
        return ctx.literal(window_name)
1717
1718
class ForUpdate(Node):
    """
    A row-locking clause such as ``FOR UPDATE [OF t, ...] [NOWAIT]``.

    :param expr: the clause text, or ``True`` for ``'FOR UPDATE'``. A
        trailing "nowait" (any case) is stripped and turned into the flag.
    :param of: a node or list/set/tuple of nodes for ``OF ...``.
    :param nowait: append ``NOWAIT`` when truthy.
    """
    def __init__(self, expr, of=None, nowait=None):
        if expr is True:
            expr = 'FOR UPDATE'
        if expr.lower().endswith('nowait'):
            expr = expr[:-7]  # Strip off the "nowait" bit.
            nowait = True

        self._expr = expr
        if not (of is None or isinstance(of, (list, set, tuple))):
            of = (of,)
        self._of = of
        self._nowait = nowait

    def __sql__(self, ctx):
        ctx.literal(self._expr)
        if self._of is not None:
            ctx.literal(' OF ').sql(CommaNodeList(self._of))
        if self._nowait:
            ctx.literal(' NOWAIT')
        return ctx
1739
1740
def Case(predicate, expression_tuples, default=None):
    """
    Build ``CASE [predicate] WHEN e THEN v ... [ELSE default] END``.

    :param predicate: optional operand placed after ``CASE``.
    :param expression_tuples: iterable of ``(when_expr, then_value)`` pairs.
    :param default: optional ``ELSE`` value (omitted when ``None``).
    """
    parts = [SQL('CASE')]
    if predicate is not None:
        parts.append(predicate)
    for when_expr, then_value in expression_tuples:
        parts += [SQL('WHEN'), when_expr, SQL('THEN'), then_value]
    if default is not None:
        parts += [SQL('ELSE'), default]
    parts.append(SQL('END'))
    return NodeList(parts)
1751
1752
class NodeList(ColumnBase):
    """
    A sequence of nodes rendered in order, separated by ``glue``.

    :param nodes: sequence of nodes/values to render.
    :param glue: literal text placed between consecutive nodes.
    :param parens: wrap the output in parentheses; an empty list then
        renders as ``()``.
    """
    def __init__(self, nodes, glue=' ', parens=False):
        self.nodes = nodes
        self.glue = glue
        self.parens = parens
        lone_nested_expr = (parens and len(nodes) == 1 and
                            isinstance(nodes[0], Expression) and
                            not nodes[0].flat)
        if lone_nested_expr:
            # Hack to avoid double-parentheses: this list already supplies
            # them, so render a flattened copy of the single expression.
            flat_expr = nodes[0].clone()
            flat_expr.flat = True
            self.nodes = (flat_expr,)

    def __sql__(self, ctx):
        nodes = self.nodes
        last = len(nodes) - 1
        if last < 0:
            return ctx.literal('()') if self.parens else ctx
        with ctx(parentheses=self.parens):
            for idx in range(last):
                ctx.sql(nodes[idx])
                ctx.literal(self.glue)
            ctx.sql(nodes[last])
        return ctx
1775
1776
def CommaNodeList(nodes):
    """A ``NodeList`` joined by ``', '`` without parentheses."""
    return NodeList(nodes, glue=', ')
1779
1780
def EnclosedNodeList(nodes):
    """A ``NodeList`` joined by ``', '`` and wrapped in parentheses."""
    return NodeList(nodes, glue=', ', parens=True)
1783
1784
class _Namespace(Node):
    """
    A named SQL namespace; ``ns.attr`` and ``ns['attr']`` both build a
    ``NamespaceAttribute``.
    """
    __slots__ = ('_name',)

    def __init__(self, name):
        self._name = name

    def __getattr__(self, attr):
        return NamespaceAttribute(self, attr)

    def __getitem__(self, attr):
        return NamespaceAttribute(self, attr)
1792
class NamespaceAttribute(ColumnBase):
    """An attribute of a namespace, rendered as ``<namespace>.<"attribute">``."""
    def __init__(self, namespace, attribute):
        self._namespace = namespace
        self._attribute = attribute

    def __sql__(self, ctx):
        prefix = self._namespace._name + '.'
        return ctx.literal(prefix).sql(Entity(self._attribute))
1802
# The EXCLUDED namespace: ``EXCLUDED.col`` renders as ``EXCLUDED."col"``.
EXCLUDED = _Namespace('EXCLUDED')
1804
1805
class DQ(ColumnBase):
    """
    Holds keyword filters (``query``) plus a negation flag.

    ``~dq`` is Node.copy-decorated and toggles the flag.
    """
    def __init__(self, **query):
        super(DQ, self).__init__()
        self.query = query
        self._negated = False

    @Node.copy
    def __invert__(self):
        self._negated = not self._negated

    def clone(self):
        duplicate = DQ(**self.query)
        duplicate._negated = self._negated
        return duplicate
1820
def Tuple(*nodes):
    """Represent a row tuple: the nodes comma-separated inside parentheses."""
    return EnclosedNodeList(nodes)
1823
1824
class QualifiedNames(WrappedNode):
    """Render the wrapped node inside ``ctx.scope_column()``."""
    def __sql__(self, ctx):
        with ctx.scope_column():
            rendered = ctx.sql(self.node)
        return rendered
1829
1830
def qualify_names(node):
    """
    Search a node hierarchy so that column-like objects are referenced by
    fully-qualified names.

    Expressions are rebuilt recursively (same class, op and ``flat``). Any
    other ``ColumnBase`` is wrapped in ``QualifiedNames``. Everything else is
    returned unchanged.
    """
    if not isinstance(node, ColumnBase):
        return node
    if isinstance(node, Expression):
        lhs = qualify_names(node.lhs)
        rhs = qualify_names(node.rhs)
        return node.__class__(lhs, node.op, rhs, node.flat)
    return QualifiedNames(node)
1840
1841
class OnConflict(Node):
    """
    Conflict-resolution settings for an INSERT (action, update data,
    preserved columns, filters and conflict target/constraint).

    Rendering is delegated to ``ctx.state.conflict_statement`` and
    ``ctx.state.conflict_update``.

    :raises ValueError: if both ``conflict_target`` and
        ``conflict_constraint`` are given.
    """
    def __init__(self, action=None, update=None, preserve=None, where=None,
                 conflict_target=None, conflict_where=None,
                 conflict_constraint=None):
        self._action = action
        self._update = update
        self._preserve = ensure_tuple(preserve)
        self._where = where
        if conflict_target is not None and conflict_constraint is not None:
            raise ValueError('only one of "conflict_target" and '
                             '"conflict_constraint" may be specified.')
        self._conflict_target = ensure_tuple(conflict_target)
        self._conflict_where = conflict_where
        self._conflict_constraint = conflict_constraint

    def get_conflict_statement(self, ctx, query):
        return ctx.state.conflict_statement(self, query)

    def get_conflict_update(self, ctx, query):
        return ctx.state.conflict_update(self, query)

    @Node.copy
    def preserve(self, *columns):
        self._preserve = columns

    @Node.copy
    def update(self, _data=None, **kwargs):
        """
        Set the data to apply on conflict, from a dict and/or keyword args.

        :raises ValueError: if keyword args are mixed with non-dict data.
        """
        if _data and kwargs and not isinstance(_data, dict):
            raise ValueError('Cannot mix data with keyword arguments in the '
                             'OnConflict update method.')
        if kwargs:
            # Merge into a fresh dict so the caller's ``_data`` is not
            # mutated as a side effect.
            merged = dict(_data) if _data else {}
            merged.update(kwargs)
            _data = merged
        self._update = _data or {}

    @Node.copy
    def where(self, *expressions):
        """AND the expressions onto any existing ``where`` filter."""
        if self._where is not None:
            expressions = (self._where,) + expressions
        self._where = reduce(operator.and_, expressions)

    @Node.copy
    def conflict_target(self, *constraints):
        # Target and named constraint are mutually exclusive.
        self._conflict_constraint = None
        self._conflict_target = constraints

    @Node.copy
    def conflict_where(self, *expressions):
        """AND the expressions onto any existing conflict filter."""
        if self._conflict_where is not None:
            expressions = (self._conflict_where,) + expressions
        self._conflict_where = reduce(operator.and_, expressions)

    @Node.copy
    def conflict_constraint(self, constraint):
        # Target and named constraint are mutually exclusive.
        self._conflict_constraint = constraint
        self._conflict_target = None
1898
1899
def database_required(method):
    """
    Decorator for query methods that need a database.

    The wrapped method receives, right after ``self``, either the database
    the caller passed or, when that is ``None``, the query's bound
    ``_database``. Raises ``InterfaceError`` if the result is falsy.
    """
    @wraps(method)
    def inner(self, database=None, *args, **kwargs):
        if database is None:
            database = self._database
        if database:
            return method(self, database, *args, **kwargs)
        raise InterfaceError('Query must be bound to a database in order '
                             'to call "%s".' % method.__name__)
    return inner
1909
1910# BASE QUERY INTERFACE.
1911
class BaseQuery(Node):
    """
    Base class for executable queries.

    Holds the bound database, the cursor wrapper produced by execution, and
    the row type used to build rows (dicts by default). Iteration, indexing
    and ``len()`` execute the query lazily on first use.
    """
    default_row_type = ROW.DICT

    def __init__(self, _database=None, **kwargs):
        self._database = _database
        self._cursor_wrapper = None
        self._row_type = None
        self._constructor = None
        super(BaseQuery, self).__init__(**kwargs)

    def bind(self, database=None):
        """Bind to ``database`` in place and return self."""
        self._database = database
        return self

    def clone(self):
        # A clone must not share execution results with the original.
        query = super(BaseQuery, self).clone()
        query._cursor_wrapper = None
        return query

    @Node.copy
    def dicts(self, as_dict=True):
        """Return rows as dicts (a false flag resets to the default type)."""
        self._row_type = ROW.DICT if as_dict else None
        return self

    @Node.copy
    def tuples(self, as_tuple=True):
        """Return rows as tuples (a false flag resets to the default type)."""
        self._row_type = ROW.TUPLE if as_tuple else None
        return self

    @Node.copy
    def namedtuples(self, as_namedtuple=True):
        """Return rows as namedtuples (false resets to the default type)."""
        self._row_type = ROW.NAMED_TUPLE if as_namedtuple else None
        return self

    @Node.copy
    def objects(self, constructor=None):
        """Build each row with ``constructor`` (None resets to default)."""
        self._row_type = ROW.CONSTRUCTOR if constructor else None
        self._constructor = constructor
        return self

    def _get_cursor_wrapper(self, cursor):
        """
        Wrap ``cursor`` according to the configured (or default) row type.

        :raises ValueError: for an unrecognized row type.
        """
        row_type = self._row_type or self.default_row_type

        if row_type == ROW.DICT:
            return DictCursorWrapper(cursor)
        elif row_type == ROW.TUPLE:
            return CursorWrapper(cursor)
        elif row_type == ROW.NAMED_TUPLE:
            return NamedTupleCursorWrapper(cursor)
        elif row_type == ROW.CONSTRUCTOR:
            return ObjectCursorWrapper(cursor, self._constructor)
        else:
            raise ValueError('Unrecognized row type: "%s".' % row_type)

    def __sql__(self, ctx):
        raise NotImplementedError

    def sql(self):
        """
        Render the query using the bound database's SQL context (or a
        default ``Context``) and return the result of ``parse()``,
        presumably ``(sql, params)``.
        """
        if self._database:
            context = self._database.get_sql_context()
        else:
            context = Context()
        return context.parse(self)

    @database_required
    def execute(self, database):
        """Execute against ``database`` (defaults to the bound database)."""
        return self._execute(database)

    def _execute(self, database):
        # Implemented by subclasses.
        raise NotImplementedError

    def iterator(self, database=None):
        """Execute and return an iterator over the cursor wrapper's rows."""
        return iter(self.execute(database).iterator())

    def _ensure_execution(self):
        # Execute lazily on first access; an unbound query cannot do so.
        if not self._cursor_wrapper:
            if not self._database:
                raise ValueError('Query has not been executed.')
            self.execute()

    def __iter__(self):
        self._ensure_execution()
        return iter(self._cursor_wrapper)

    def __getitem__(self, value):
        self._ensure_execution()
        # Fill the row cache just far enough to satisfy the index or slice.
        # A slice with no stop passes None; a negative index passes 0. Both
        # presumably mean "fill everything" (confirm in fill_cache).
        if isinstance(value, slice):
            index = value.stop
        else:
            index = value
        if index is not None:
            index = index + 1 if index >= 0 else 0
        self._cursor_wrapper.fill_cache(index)
        return self._cursor_wrapper.row_cache[value]

    def __len__(self):
        self._ensure_execution()
        return len(self._cursor_wrapper)

    def __str__(self):
        return query_to_string(self)
2013
2014
class RawQuery(BaseQuery):
    """Query wrapping a literal SQL string plus its parameters."""

    def __init__(self, sql=None, params=None, **kwargs):
        super(RawQuery, self).__init__(**kwargs)
        self._sql = sql
        self._params = params

    def __sql__(self, ctx):
        """Emit the SQL text verbatim, then register each parameter value."""
        ctx.literal(self._sql)
        for param in (self._params or ()):
            ctx.value(param, add_param=False)
        return ctx

    def _execute(self, database):
        """Run the SQL once and reuse the cached cursor wrapper afterwards."""
        if self._cursor_wrapper is not None:
            return self._cursor_wrapper
        self._cursor_wrapper = self._get_cursor_wrapper(database.execute(self))
        return self._cursor_wrapper
2033
2034
class Query(BaseQuery):
    """Base for queries with WHERE, ORDER BY, LIMIT/OFFSET and CTE support.

    The ``Node.copy``-wrapped methods are the chainable clause builders.
    """
    def __init__(self, where=None, order_by=None, limit=None, offset=None,
                 **kwargs):
        super(Query, self).__init__(**kwargs)
        self._where = where
        self._order_by = order_by
        self._limit = limit
        self._offset = offset

        self._cte_list = None

    @Node.copy
    def with_cte(self, *cte_list):
        """Attach common table expressions, rendered as a leading WITH."""
        self._cte_list = cte_list

    @Node.copy
    def where(self, *expressions):
        """AND the expressions onto any existing WHERE condition."""
        if self._where is not None:
            expressions = (self._where,) + expressions
        self._where = reduce(operator.and_, expressions)

    @Node.copy
    def orwhere(self, *expressions):
        """OR the expressions onto any existing WHERE condition."""
        if self._where is not None:
            expressions = (self._where,) + expressions
        self._where = reduce(operator.or_, expressions)

    @Node.copy
    def order_by(self, *values):
        """Replace the ORDER BY terms (no arguments clears ordering)."""
        self._order_by = values

    @Node.copy
    def order_by_extend(self, *values):
        """Append ORDER BY terms to the existing ones."""
        # tuple() lets this work when order_by was given to __init__ as a
        # list (list + tuple raises TypeError); empty results become None.
        self._order_by = (tuple(self._order_by or ()) + values) or None

    @Node.copy
    def limit(self, value=None):
        self._limit = value

    @Node.copy
    def offset(self, value=None):
        self._offset = value

    @Node.copy
    def paginate(self, page, paginate_by=20):
        """Set LIMIT/OFFSET for a 1-based *page* of *paginate_by* rows.

        Page 0 is treated like page 1.
        """
        if page > 0:
            page -= 1
        self._limit = paginate_by
        self._offset = page * paginate_by

    def _apply_ordering(self, ctx):
        """Append ORDER BY, LIMIT and OFFSET clauses to *ctx*."""
        if self._order_by:
            (ctx
             .literal(' ORDER BY ')
             .sql(CommaNodeList(self._order_by)))
        # With only an OFFSET, emit LIMIT using the dialect's limit_max when
        # one is configured (presumably for databases that need LIMIT
        # alongside OFFSET).
        if self._limit is not None or (self._offset is not None and
                                       ctx.state.limit_max):
            limit = ctx.state.limit_max if self._limit is None else self._limit
            ctx.literal(' LIMIT ').sql(limit)
        if self._offset is not None:
            ctx.literal(' OFFSET ').sql(self._offset)
        return ctx

    def __sql__(self, ctx):
        """Render the leading WITH clause, if any CTEs are attached."""
        if self._cte_list:
            # The CTE scope is only used at the very beginning of the query,
            # when we are describing the various CTEs we will be using.
            recursive = any(cte._recursive for cte in self._cte_list)

            # Explicitly disable the "subquery" flag here, so as to avoid
            # unnecessary parentheses around subsequent selects.
            with ctx.scope_cte(subquery=False):
                (ctx
                 .literal('WITH RECURSIVE ' if recursive else 'WITH ')
                 .sql(CommaNodeList(self._cte_list))
                 .literal(' '))
        return ctx
2112
2113
def __compound_select__(operation, inverted=False):
    """Create a binary method joining two queries with *operation*.

    With *inverted* true the operands are swapped, as the reflected
    operators (``__radd__``, ``__ror__``, ...) require.
    """
    def method(self, other):
        left, right = (other, self) if inverted else (self, other)
        return CompoundSelectQuery(left, operation, right)
    return method
2120
2121
class SelectQuery(Query):
    """Row-producing query that can be combined with set operators.

    ``+``, ``|``, ``&`` and ``-`` (plus their reflected forms) build
    UNION ALL, UNION, INTERSECT and EXCEPT compound queries.
    """
    union_all = __compound_select__('UNION ALL')
    __add__ = union_all
    __radd__ = __compound_select__('UNION ALL', inverted=True)

    union = __compound_select__('UNION')
    __or__ = union
    __ror__ = __compound_select__('UNION', inverted=True)

    intersect = __compound_select__('INTERSECT')
    __and__ = intersect
    __rand__ = __compound_select__('INTERSECT', inverted=True)

    except_ = __compound_select__('EXCEPT')
    __sub__ = except_
    __rsub__ = __compound_select__('EXCEPT', inverted=True)

    def select_from(self, *columns):
        """Select *columns* from this query used as a sub-select."""
        if not columns:
            raise ValueError('select_from() must specify one or more columns.')

        outer = Select((self,), columns).bind(self._database)
        model = getattr(self, 'model', None)
        # Bind to the sub-select's model type, if defined.
        return outer if model is None else outer.objects(model)
2142
2143
class SelectBase(_HashableSource, Source, SelectQuery):
    """Shared execution helpers for row-returning queries usable as sources."""

    def _get_hash(self):
        # Aliased queries hash by alias; unaliased ones by object identity.
        return hash((self.__class__, self._alias or id(self)))

    def _execute(self, database):
        """Execute once and cache the cursor wrapper for later access."""
        if self._cursor_wrapper is None:
            cursor = database.execute(self)
            self._cursor_wrapper = self._get_cursor_wrapper(cursor)
        return self._cursor_wrapper

    @database_required
    def peek(self, database, n=1):
        """Return the first row (n == 1) or a list of the first *n* rows.

        Returns None when the result set is empty.
        """
        rows = self.execute(database)[:n]
        if rows:
            return rows[0] if n == 1 else rows

    @database_required
    def first(self, database, n=1):
        """Like peek(), but also sets this query's LIMIT to *n*.

        Note this mutates the query itself and discards cached results when
        the limit changes.
        """
        if self._limit != n:
            self._limit = n
            self._cursor_wrapper = None
        return self.peek(database, n=n)

    @database_required
    def scalar(self, database, as_tuple=False):
        """Return the first column of the first row (or the whole row tuple
        when *as_tuple* is true); None if there are no rows."""
        row = self.tuples().peek(database)
        return row[0] if row and not as_tuple else row

    @database_required
    def count(self, database, clear_limit=False):
        """Return the number of rows by wrapping this query in COUNT(1)."""
        clone = self.order_by().alias('_wrapped')
        if clear_limit:
            clone._limit = clone._offset = None
        try:
            # Without grouping/windows/DISTINCT the selected columns do not
            # affect the row count, so select a constant instead.
            if clone._having is None and clone._group_by is None and \
               clone._windows is None and clone._distinct is None and \
               clone._simple_distinct is not True:
                clone = clone.select(SQL('1'))
        except AttributeError:
            pass
        return Select([clone], [fn.COUNT(SQL('1'))]).scalar(database)

    @database_required
    def exists(self, database):
        """Return True if the query produces at least one row."""
        clone = self.columns(SQL('1'))
        clone._limit = 1
        clone._offset = None
        # Pass the resolved database through: the clone only carries
        # self._database, which is None when the caller supplied the
        # database explicitly to an unbound query.
        return bool(clone.scalar(database))

    @database_required
    def get(self, database):
        """Re-execute and return the first row, or None if there is none."""
        self._cursor_wrapper = None
        try:
            return self.execute(database)[0]
        except IndexError:
            pass
2200
2201
2202# QUERY IMPLEMENTATIONS.
2203
2204
class CompoundSelectQuery(SelectBase):
    """Two queries combined by a set operator (e.g. UNION, EXCEPT).

    Built by the operator methods on SelectQuery (``union``, ``|``, ``-``
    and friends).
    """
    def __init__(self, lhs, op, rhs):
        super(CompoundSelectQuery, self).__init__()
        self.lhs = lhs  # left-hand query
        self.op = op  # operator keyword emitted between the two queries
        self.rhs = rhs  # right-hand query

    @property
    def _returning(self):
        # The result columns are those of the left-hand query.
        return self.lhs._returning

    @database_required
    def exists(self, database):
        # Wrap this query (limited to one row) as ``SELECT 1 FROM (...)`` and
        # report whether a truthy value came back.
        query = Select((self.limit(1),), (SQL('1'),)).bind(database)
        return bool(query.scalar())

    def _get_query_key(self):
        # NOTE(review): this calls get_query_key() (no leading underscore) on
        # both operands, while Select in this chunk only defines
        # _get_query_key(); confirm get_query_key exists on a base class.
        return (self.lhs.get_query_key(), self.rhs.get_query_key())

    def _wrap_parens(self, ctx, subq):
        """Decide whether operand *subq* should be wrapped in parentheses.

        Controlled by ``ctx.state.compound_select_parentheses``:
        unset/NEVER -> False; ALWAYS -> True; UNNESTED -> False when rendered
        inside an expression or function, otherwise True unless *subq* is
        itself a compound query. Any other setting falls through and
        returns None (falsy).
        """
        csq_setting = ctx.state.compound_select_parentheses

        if not csq_setting or csq_setting == CSQ_PARENTHESES_NEVER:
            return False
        elif csq_setting == CSQ_PARENTHESES_ALWAYS:
            return True
        elif csq_setting == CSQ_PARENTHESES_UNNESTED:
            if ctx.state.in_expr or ctx.state.in_function:
                # If this compound select query is being used inside an
                # expression, e.g., an IN or EXISTS().
                return False

            # If the query on the left or right is itself a compound select
            # query, then we do not apply parentheses. However, if it is a
            # regular SELECT query, we will apply parentheses.
            return not isinstance(subq, CompoundSelectQuery)

    def __sql__(self, ctx):
        """Render ``lhs <op> rhs`` followed by ORDER BY/LIMIT/OFFSET."""
        if ctx.scope == SCOPE_COLUMN:
            return self.apply_column(ctx)

        # Call parent method to handle any CTEs.
        super(CompoundSelectQuery, self).__sql__(ctx)

        # Parenthesize the whole compound query when it is a subquery or is
        # being rendered as a FROM source.
        outer_parens = ctx.subquery or (ctx.scope == SCOPE_SOURCE)
        with ctx(parentheses=outer_parens):
            # Should the left-hand query be wrapped in parentheses?
            lhs_parens = self._wrap_parens(ctx, self.lhs)
            with ctx.scope_normal(parentheses=lhs_parens, subquery=False):
                ctx.sql(self.lhs)
            ctx.literal(' %s ' % self.op)
            # presumably gives the right-hand query its own alias namespace
            # -- confirm against Context.push_alias.
            with ctx.push_alias():
                # Should the right-hand query be wrapped in parentheses?
                rhs_parens = self._wrap_parens(ctx, self.rhs)
                with ctx.scope_normal(parentheses=rhs_parens, subquery=False):
                    ctx.sql(self.rhs)

            # Apply ORDER BY, LIMIT, OFFSET. We use the "values" scope so that
            # entity names are not fully-qualified. This is a bit of a hack, as
            # we're relying on the logic in Column.__sql__() to not fully
            # qualify column names.
            with ctx.scope_values():
                self._apply_ordering(ctx)

        return self.apply_alias(ctx)
2270
2271
class Select(SelectBase):
    """A SELECT query.

    ``distinct`` may be a bool (plain ``DISTINCT``) or a sequence of
    expressions (``DISTINCT ON (...)``). ``windows`` supplies named window
    definitions for the WINDOW clause.
    """
    def __init__(self, from_list=None, columns=None, group_by=None,
                 having=None, distinct=None, windows=None, for_update=None,
                 for_update_of=None, nowait=None, lateral=None, **kwargs):
        super(Select, self).__init__(**kwargs)
        self._from_list = (list(from_list) if isinstance(from_list, tuple)
                           else from_list) or []
        self._returning = columns
        self._group_by = group_by
        self._having = having
        # Honour the ``windows`` argument (previously it was discarded);
        # empty values normalize to None, as in window().
        self._windows = windows if windows else None
        self._for_update = for_update  # XXX: consider reorganizing.
        self._for_update_of = for_update_of
        self._for_update_nowait = nowait
        self._lateral = lateral

        self._distinct = self._simple_distinct = None
        if distinct:
            if isinstance(distinct, bool):
                self._simple_distinct = distinct
            else:
                self._distinct = distinct

        self._cursor_wrapper = None

    def clone(self):
        """Copy the query, giving the copy its own FROM list."""
        clone = super(Select, self).clone()
        if clone._from_list:
            clone._from_list = list(clone._from_list)
        return clone

    @Node.copy
    def columns(self, *columns, **kwargs):
        """Replace the selected columns."""
        self._returning = columns
    select = columns

    @Node.copy
    def select_extend(self, *columns):
        """Append columns to the selection (works when none are set yet)."""
        self._returning = tuple(self._returning or ()) + columns

    @Node.copy
    def from_(self, *sources):
        """Replace the FROM sources."""
        self._from_list = list(sources)

    @Node.copy
    def join(self, dest, join_type=JOIN.INNER, on=None):
        """Join *dest* onto the most recently added FROM source."""
        if not self._from_list:
            raise ValueError('No sources to join on.')
        item = self._from_list.pop()
        self._from_list.append(Join(item, dest, join_type, on))

    @Node.copy
    def group_by(self, *columns):
        """Set GROUP BY; a Table expands to all of its declared columns."""
        grouping = []
        for column in columns:
            if isinstance(column, Table):
                if not column._columns:
                    raise ValueError('Cannot pass a table to group_by() that '
                                     'does not have columns explicitly '
                                     'declared.')
                grouping.extend([getattr(column, col_name)
                                 for col_name in column._columns])
            else:
                grouping.append(column)
        self._group_by = grouping

    def group_by_extend(self, *values):
        """@Node.copy used from group_by() call"""
        group_by = tuple(self._group_by or ()) + values
        return self.group_by(*group_by)

    @Node.copy
    def having(self, *expressions):
        """AND the expressions onto any existing HAVING condition."""
        if self._having is not None:
            expressions = (self._having,) + expressions
        self._having = reduce(operator.and_, expressions)

    @Node.copy
    def distinct(self, *columns):
        """Enable/disable plain DISTINCT (single bool) or set DISTINCT ON."""
        if len(columns) == 1 and (columns[0] is True or columns[0] is False):
            self._simple_distinct = columns[0]
            # Drop any earlier DISTINCT ON columns: otherwise distinct(False)
            # could not turn DISTINCT off, and distinct(True) would still
            # render DISTINCT ON.
            self._distinct = None
        else:
            self._simple_distinct = False
            self._distinct = columns

    @Node.copy
    def window(self, *windows):
        """Set named window definitions (no arguments clears them)."""
        self._windows = windows if windows else None

    @Node.copy
    def for_update(self, for_update=True, of=None, nowait=None):
        """Request a FOR UPDATE lock; ``of``/``nowait`` imply it."""
        if not for_update and (of is not None or nowait):
            for_update = True
        self._for_update = for_update
        self._for_update_of = of
        self._for_update_nowait = nowait

    @Node.copy
    def lateral(self, lateral=True):
        """Mark the query LATERAL when used as a FROM source."""
        self._lateral = lateral

    def _get_query_key(self):
        return self._alias

    def __sql_selection__(self, ctx, is_subquery=False):
        return ctx.sql(CommaNodeList(self._returning))

    def __sql__(self, ctx):
        if ctx.scope == SCOPE_COLUMN:
            return self.apply_column(ctx)

        if self._lateral and ctx.scope == SCOPE_SOURCE:
            ctx.literal('LATERAL ')

        is_subquery = ctx.subquery
        state = {
            'converter': None,
            'in_function': False,
            'parentheses': is_subquery or (ctx.scope == SCOPE_SOURCE),
            'subquery': True,
        }
        # A sub-select that is the only argument of a function call gets no
        # extra parentheses of its own.
        if ctx.state.in_function and ctx.state.function_arg_count == 1:
            state['parentheses'] = False

        with ctx.scope_normal(**state):
            # Defer calling parent SQL until here. This ensures that any CTEs
            # for this query will be properly nested if this query is a
            # sub-select or is used in an expression. See GH#1809 for example.
            super(Select, self).__sql__(ctx)

            ctx.literal('SELECT ')
            if self._simple_distinct or self._distinct is not None:
                ctx.literal('DISTINCT ')
                if self._distinct:
                    (ctx
                     .literal('ON ')
                     .sql(EnclosedNodeList(self._distinct))
                     .literal(' '))

            with ctx.scope_source():
                ctx = self.__sql_selection__(ctx, is_subquery)

            if self._from_list:
                with ctx.scope_source(parentheses=False):
                    ctx.literal(' FROM ').sql(CommaNodeList(self._from_list))

            if self._where is not None:
                ctx.literal(' WHERE ').sql(self._where)

            if self._group_by:
                ctx.literal(' GROUP BY ').sql(CommaNodeList(self._group_by))

            if self._having is not None:
                ctx.literal(' HAVING ').sql(self._having)

            if self._windows is not None:
                ctx.literal(' WINDOW ')
                ctx.sql(CommaNodeList(self._windows))

            # Apply ORDER BY, LIMIT, OFFSET.
            self._apply_ordering(ctx)

            if self._for_update:
                if not ctx.state.for_update:
                    raise ValueError('FOR UPDATE specified but not supported '
                                     'by database.')
                ctx.literal(' ')
                ctx.sql(ForUpdate(self._for_update, self._for_update_of,
                                  self._for_update_nowait))

        # If the subquery is inside a function -or- we are evaluating a
        # subquery on either side of an expression w/o an explicit alias, do
        # not generate an alias + AS clause.
        if ctx.state.in_function or (ctx.state.in_expr and
                                     self._alias is None):
            return ctx

        return self.apply_alias(ctx)
2450
2451
class _WriteQuery(Query):
    """Shared machinery for INSERT, UPDATE and DELETE queries."""

    def __init__(self, table, returning=None, **kwargs):
        self.table = table
        self._returning = returning
        # A user-requested RETURNING clause means the caller gets the cursor.
        self._return_cursor = bool(returning)
        super(_WriteQuery, self).__init__(**kwargs)

    @Node.copy
    def returning(self, *returning):
        """Set the RETURNING columns (no arguments clears them)."""
        self._returning = returning
        self._return_cursor = bool(returning)

    def apply_returning(self, ctx):
        """Append a RETURNING clause to *ctx* when columns are set."""
        if not self._returning:
            return ctx
        with ctx.scope_source():
            ctx.literal(' RETURNING ').sql(CommaNodeList(self._returning))
        return ctx

    def _execute(self, database):
        cursor = (self.execute_returning(database) if self._returning
                  else database.execute(self))
        return self.handle_result(database, cursor)

    def execute_returning(self, database):
        """Execute once and cache a row-wrapping cursor for RETURNING."""
        if self._cursor_wrapper is not None:
            return self._cursor_wrapper
        self._cursor_wrapper = self._get_cursor_wrapper(database.execute(self))
        return self._cursor_wrapper

    def handle_result(self, database, cursor):
        """Return the cursor if requested, else the affected row count."""
        return (cursor if self._return_cursor
                else database.rows_affected(cursor))

    def _set_table_alias(self, ctx):
        ctx.alias_manager[self.table] = self.table.__name__

    def __sql__(self, ctx):
        super(_WriteQuery, self).__sql__(ctx)
        # Pin the table's alias to its own name so that a sub-select which
        # references the outer table renders table.column rather than a
        # freshly generated alias such as t2.
        self._set_table_alias(ctx)
        return ctx
2498
2499
class Update(_WriteQuery):
    """UPDATE <table> SET ... [FROM ...] [WHERE ...] [ORDER BY/LIMIT] [RETURNING].

    ``update`` is a mapping of column -> new value or expression.
    """
    def __init__(self, table, update=None, **kwargs):
        super(Update, self).__init__(table, **kwargs)
        self._update = update
        self._from = None

    @Node.copy
    def from_(self, *sources):
        """Set additional FROM sources for the UPDATE."""
        self._from = sources

    def __sql__(self, ctx):
        super(Update, self).__sql__(ctx)

        with ctx.scope_values(subquery=True):
            ctx.literal('UPDATE ')

            expressions = []
            for k, v in sorted(self._update.items(), key=ctx.column_sort_key):
                # Plain Python values become parameters, converted by the
                # field when the key is a Field.
                if not isinstance(v, Node):
                    if isinstance(k, Field):
                        v = k.to_value(v)
                    else:
                        v = Value(v, unpack=False)
                # Expressions are passed through qualify_names (presumably so
                # column references render table-qualified -- confirm).
                if not isinstance(v, Value):
                    v = qualify_names(v)
                expressions.append(NodeList((k, SQL('='), v)))

            (ctx
             .sql(self.table)
             .literal(' SET ')
             .sql(CommaNodeList(expressions)))

            if self._from:
                with ctx.scope_source(parentheses=False):
                    ctx.literal(' FROM ').sql(CommaNodeList(self._from))

            # Compare against None (as Select and Delete do): a falsy WHERE
            # such as where(False) must not be dropped, or the UPDATE would
            # silently apply to every row.
            if self._where is not None:
                with ctx.scope_normal():
                    ctx.literal(' WHERE ').sql(self._where)
            self._apply_ordering(ctx)
            return self.apply_returning(ctx)
2541
2542
class Insert(_WriteQuery):
    """INSERT query for a single row, many rows, or a sub-query/raw SQL.

    While rendering, ``_query_type`` is set to SIMPLE, QUERY or MULTI; it is
    later passed to ``database.last_insert_id``.
    """
    SIMPLE = 0  # a single mapping of values
    QUERY = 1  # INSERT ... <SelectQuery or SQL>
    MULTI = 2  # an iterable of rows
    # Raised when there is no row data to insert (see _simple_insert and
    # _generate_insert); __sql__ and _execute catch it.
    class DefaultValuesException(Exception): pass

    def __init__(self, table, insert=None, columns=None, on_conflict=None,
                 **kwargs):
        super(Insert, self).__init__(table, **kwargs)
        self._insert = insert  # mapping, iterable of rows, or query/SQL
        self._columns = columns  # explicit target columns, if any
        self._on_conflict = on_conflict  # OnConflict instance or None
        self._query_type = None  # set by __sql__

    def where(self, *expressions):
        """Not supported: INSERT has no WHERE clause."""
        raise NotImplementedError('INSERT queries cannot have a WHERE clause.')

    @Node.copy
    def on_conflict_ignore(self, ignore=True):
        self._on_conflict = OnConflict('IGNORE') if ignore else None

    @Node.copy
    def on_conflict_replace(self, replace=True):
        self._on_conflict = OnConflict('REPLACE') if replace else None

    @Node.copy
    def on_conflict(self, *args, **kwargs):
        # No arguments clears any conflict handling.
        self._on_conflict = (OnConflict(*args, **kwargs) if (args or kwargs)
                             else None)

    def _simple_insert(self, ctx):
        """Render a single-row insert from the mapping in ``_insert``."""
        if not self._insert:
            raise self.DefaultValuesException('Error: no data to insert.')
        return self._generate_insert((self._insert,), ctx)

    def get_default_data(self):
        """Return a column -> default mapping; the base class has none.

        Callable defaults are invoked once per row by _generate_insert.
        """
        return {}

    def get_default_columns(self):
        """Return the table's declared columns minus the primary key.

        Returns None when the table declares no columns.
        """
        if self.table._columns:
            return [getattr(self.table, col) for col in self.table._columns
                    if col != self.table._primary_key]

    def _generate_insert(self, insert, ctx):
        """Render ``(<columns>) VALUES (...), (...)`` for an iterable of rows.

        Rows may be mappings (keyed by column object, ``column.name`` or, for
        fields, ``column_name``) or positional sequences ordered like the
        column list. Missing values come from get_default_data(), then NULL
        for nullable fields; otherwise ValueError is raised. Raises
        DefaultValuesException when there are no rows.
        """
        rows_iter = iter(insert)
        columns = self._columns

        # Load and organize column defaults (if provided).
        defaults = self.get_default_data()

        # First figure out what columns are being inserted (if they weren't
        # specified explicitly). Resulting columns are normalized and ordered.
        if not columns:
            try:
                row = next(rows_iter)
            except StopIteration:
                raise self.DefaultValuesException('Error: no rows to insert.')

            if not isinstance(row, Mapping):
                columns = self.get_default_columns()
                if columns is None:
                    raise ValueError('Bulk insert must specify columns.')
            else:
                # Infer column names from the dict of data being inserted.
                accum = []
                for column in row:
                    if isinstance(column, basestring):
                        column = getattr(self.table, column)
                    accum.append(column)

                # Add any columns present in the default data that are not
                # accounted for by the dictionary of row data.
                column_set = set(accum)
                for col in (set(defaults) - column_set):
                    accum.append(col)

                columns = sorted(accum, key=lambda obj: obj.get_sort_key(ctx))
            # Put the row consumed for inspection back in front.
            rows_iter = itertools.chain(iter((row,)), rows_iter)
        else:
            # Resolve string column names to table attributes and append any
            # defaulted columns that were not listed explicitly.
            clean_columns = []
            seen = set()
            for column in columns:
                if isinstance(column, basestring):
                    column_obj = getattr(self.table, column)
                else:
                    column_obj = column
                clean_columns.append(column_obj)
                seen.add(column_obj)

            columns = clean_columns
            for col in sorted(defaults, key=lambda obj: obj.get_sort_key(ctx)):
                if col not in seen:
                    columns.append(col)

        # Precompute, per column, the keys tried against mapping rows and
        # which columns may fall back to NULL.
        nullable_columns = set()
        value_lookups = {}
        for column in columns:
            lookups = [column, column.name]
            if isinstance(column, Field):
                if column.name != column.column_name:
                    lookups.append(column.column_name)
                if column.null:
                    nullable_columns.add(column)
            value_lookups[column] = lookups

        ctx.sql(EnclosedNodeList(columns)).literal(' VALUES ')
        columns_converters = [
            (column, column.db_value if isinstance(column, Field) else None)
            for column in columns]

        all_values = []
        for row in rows_iter:
            values = []
            is_dict = isinstance(row, Mapping)
            for i, (column, converter) in enumerate(columns_converters):
                try:
                    if is_dict:
                        # The logic is a bit convoluted, but in order to be
                        # flexible in what we accept (dict keyed by
                        # column/field, field name, or underlying column name),
                        # we try accessing the row data dict using each
                        # possible key. If no match is found, throw an error.
                        for lookup in value_lookups[column]:
                            try:
                                val = row[lookup]
                            except KeyError: pass
                            else: break
                        else:
                            raise KeyError
                    else:
                        val = row[i]
                except (KeyError, IndexError):
                    if column in defaults:
                        val = defaults[column]
                        if callable_(val):
                            val = val()
                    elif column in nullable_columns:
                        val = None
                    else:
                        raise ValueError('Missing value for %s.' % column.name)

                if not isinstance(val, Node):
                    val = Value(val, converter=converter, unpack=False)
                values.append(val)

            all_values.append(EnclosedNodeList(values))

        if not all_values:
            raise self.DefaultValuesException('Error: no data to insert.')

        with ctx.scope_values(subquery=True):
            return ctx.sql(CommaNodeList(all_values))

    def _query_insert(self, ctx):
        """Render ``(<columns>) <query>`` for INSERT ... SELECT / raw SQL."""
        return (ctx
                .sql(EnclosedNodeList(self._columns))
                .literal(' ')
                .sql(self._insert))

    def _default_values(self, ctx):
        """Render the no-data form; unbound queries use ``DEFAULT VALUES``."""
        if not self._database:
            return ctx.literal('DEFAULT VALUES')
        return self._database.default_values_insert(ctx)

    def __sql__(self, ctx):
        super(Insert, self).__sql__(ctx)
        with ctx.scope_values():
            # The conflict handler may replace the leading INSERT keyword.
            stmt = None
            if self._on_conflict is not None:
                stmt = self._on_conflict.get_conflict_statement(ctx, self)

            (ctx
             .sql(stmt or SQL('INSERT'))
             .literal(' INTO ')
             .sql(self.table)
             .literal(' '))

            # NOTE(review): a Mapping passed together with explicit columns
            # falls through to the MULTI branch below, where _generate_insert
            # iterates the mapping's keys as rows -- confirm this combination
            # is never used.
            if isinstance(self._insert, Mapping) and not self._columns:
                try:
                    self._simple_insert(ctx)
                except self.DefaultValuesException:
                    self._default_values(ctx)
                self._query_type = Insert.SIMPLE
            elif isinstance(self._insert, (SelectQuery, SQL)):
                self._query_insert(ctx)
                self._query_type = Insert.QUERY
            else:
                self._generate_insert(self._insert, ctx)
                self._query_type = Insert.MULTI

            if self._on_conflict is not None:
                update = self._on_conflict.get_conflict_update(ctx, self)
                if update is not None:
                    ctx.literal(' ').sql(update)

            return self.apply_returning(ctx)

    def _execute(self, database):
        # On databases with RETURNING support, request the primary key. This
        # sets _returning (on this query object) but leaves _return_cursor
        # False, so handle_result still returns last_insert_id.
        if self._returning is None and database.returning_clause \
           and self.table._primary_key:
            self._returning = (self.table._primary_key,)
        try:
            return super(Insert, self)._execute(database)
        except self.DefaultValuesException:
            # Nothing to insert (e.g. an empty row list): return None.
            pass

    def handle_result(self, database, cursor):
        """Return the cursor when returning() was requested; the affected row
        count for non-simple inserts without RETURNING; otherwise the value
        of ``database.last_insert_id``."""
        if self._return_cursor:
            return cursor
        if self._query_type != Insert.SIMPLE and not self._returning:
            return database.rows_affected(cursor)
        return database.last_insert_id(cursor, self._query_type)
2755
2756
class Delete(_WriteQuery):
    """DELETE FROM <table> [WHERE ...] [ORDER BY/LIMIT/OFFSET] [RETURNING ...]."""

    def __sql__(self, ctx):
        super(Delete, self).__sql__(ctx)

        with ctx.scope_values(subquery=True):
            ctx.literal('DELETE FROM ')
            ctx.sql(self.table)

            if self._where is not None:
                with ctx.scope_normal():
                    ctx.literal(' WHERE ')
                    ctx.sql(self._where)

            self._apply_ordering(ctx)
            return self.apply_returning(ctx)
2769
2770
class Index(Node):
    """CREATE [UNIQUE] INDEX statement.

    *expressions* may mix raw SQL strings and expression nodes; *where*
    makes it a partial index and *using* names the index method.
    """

    def __init__(self, name, table, expressions, unique=False, safe=False,
                 where=None, using=None):
        self._name = name
        self._table = table if isinstance(table, Table) else Entity(table)
        self._expressions = expressions
        self._where = where
        self._unique = unique
        self._safe = safe
        self._using = using

    @Node.copy
    def safe(self, _safe=True):
        """Toggle ``IF NOT EXISTS``."""
        self._safe = _safe

    @Node.copy
    def where(self, *expressions):
        """AND the expressions onto any existing partial-index condition."""
        if self._where is not None:
            expressions = (self._where,) + expressions
        self._where = reduce(operator.and_, expressions)

    @Node.copy
    def using(self, _using=None):
        """Set the index method emitted in the USING clause."""
        self._using = _using

    def __sql__(self, ctx):
        with ctx.scope_values(subquery=True):
            if self._unique:
                ctx.literal('CREATE UNIQUE INDEX ')
            else:
                ctx.literal('CREATE INDEX ')
            if self._safe:
                ctx.literal('IF NOT EXISTS ')

            # Sqlite uses CREATE INDEX <schema>.<name> ON <table>, whereas most
            # others use: CREATE INDEX <name> ON <schema>.<table>.
            schema_on_index = (ctx.state.index_schema_prefix and
                               isinstance(self._table, Table) and
                               self._table._schema)
            if schema_on_index:
                index_name = Entity(self._table._schema, self._name)
                table_name = Entity(self._table.__name__)
            else:
                index_name, table_name = Entity(self._name), self._table

            ctx.sql(index_name)
            using_before_table = (self._using is not None and
                                  ctx.state.index_using_precedes_table)
            if using_before_table:
                ctx.literal(' USING %s' % self._using)  # MySQL style.

            ctx.literal(' ON ').sql(table_name).literal(' ')

            if self._using is not None and not using_before_table:
                ctx.literal('USING %s ' % self._using)  # Postgres/default.

            parts = [SQL(item) if isinstance(item, basestring) else item
                     for item in self._expressions]
            ctx.sql(EnclosedNodeList(parts))
            if self._where is not None:
                ctx.literal(' WHERE ').sql(self._where)

        return ctx
2834
2835
class ModelIndex(Index):
    """An Index on a model's table, with name/index type derived from fields.

    When ``name`` is omitted it is generated from the model's table name and
    the indexed columns. When ``using`` is omitted, the ``index_type`` of any
    Field that defines one is used (the last such field wins).
    """
    def __init__(self, model, fields, unique=False, safe=True, where=None,
                 using=None, name=None):
        self._model = model
        if name is None:
            name = self._generate_name_from_fields(model, fields)
        if using is None:
            for field in fields:
                if isinstance(field, Field) and hasattr(field, 'index_type'):
                    using = field.index_type
        super(ModelIndex, self).__init__(
            name=name,
            table=model._meta.table,
            expressions=fields,
            unique=unique,
            safe=safe,
            where=where,
            using=using)

    def _generate_name_from_fields(self, model, fields):
        """Build ``<table>_<col1>_<col2>...``, truncated to a safe length.

        String entries contribute their first whitespace-separated word
        (e.g. ``"name DESC"`` -> ``name``); blank strings contribute nothing.
        Other nodes are unwrapped, and Fields contribute their column name.

        Raises:
            ValueError: if no component could be derived from ``fields``.
        """
        accum = []
        for field in fields:
            if isinstance(field, basestring):
                # A blank/whitespace-only string used to raise IndexError
                # here; skip it so the explicit ValueError below applies.
                words = field.split()
                if words:
                    accum.append(words[0])
            else:
                if isinstance(field, Node) and not isinstance(field, Field):
                    field = field.unwrap()
                if isinstance(field, Field):
                    accum.append(field.column_name)

        if not accum:
            raise ValueError('Unable to generate a name for the index, please '
                             'explicitly specify a name.')

        # Strip anything that is not a word character from the joined names.
        clean_field_names = re.sub(r'[^\w]+', '', '_'.join(accum))
        meta = model._meta
        prefix = meta.name if meta.legacy_table_names else meta.table_name
        return _truncate_constraint_name('_'.join((prefix, clean_field_names)))
2874
2875
def _truncate_constraint_name(constraint, maxlen=64):
    """Return *constraint* shortened to at most *maxlen* characters.

    Names that already fit are returned unchanged. Longer names keep their
    first ``maxlen - 8`` characters, then ``_`` and the first seven hex
    digits of the MD5 of the full name, giving exactly *maxlen* characters
    (for ``maxlen >= 8``) while keeping distinct long names distinct with
    high probability.
    """
    if len(constraint) <= maxlen:
        return constraint
    digest = hashlib.md5(constraint.encode('utf-8')).hexdigest()
    return '_'.join((constraint[:(maxlen - 8)], digest[:7]))
2881
2882
2883# DB-API 2.0 EXCEPTIONS.
2884
2885
class PeeweeException(Exception):
    """Root of peewee's DB-API 2.0 style exception hierarchy.

    If the first positional argument is itself an exception (typically the
    driver's original error), it is stored as ``self.orig`` and omitted from
    ``args``.
    """
    def __init__(self, *args):
        if len(args) > 0 and isinstance(args[0], Exception):
            self.orig = args[0]
            args = args[1:]
        super(PeeweeException, self).__init__(*args)


class ImproperlyConfigured(PeeweeException):
    pass


class DatabaseError(PeeweeException):
    pass


class DataError(DatabaseError):
    pass


class IntegrityError(DatabaseError):
    pass


class InterfaceError(PeeweeException):
    pass


class InternalError(DatabaseError):
    pass


class NotSupportedError(DatabaseError):
    pass


class OperationalError(DatabaseError):
    pass


class ProgrammingError(DatabaseError):
    pass
2900
2901
class ExceptionWrapper(object):
    """Context manager that re-raises driver errors as peewee exceptions.

    ``exceptions`` maps an exception class *name* to the peewee exception
    class that should be raised in its place, so any driver's exception
    classes are recognized by name. Unmapped exceptions propagate unchanged.
    """
    __slots__ = ('exceptions',)

    def __init__(self, exceptions):
        self.exceptions = exceptions

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            return None
        mapping = self.exceptions
        # psycopg2 (2.8+) raises many fine-grained error subclasses; when one
        # is not mapped directly, fall back to its immediate base class.
        if (pg_errors is not None and exc_type.__name__ not in mapping
                and issubclass(exc_type, pg_errors.Error)):
            exc_type = exc_type.__bases__[0]
        type_name = exc_type.__name__
        if type_name not in mapping:
            return None
        new_type = mapping[type_name]
        # The original exception becomes the first arg (stored as .orig).
        reraise(new_type, new_type(exc_value, *exc_value.args), traceback)
2918
2919
# Driver exception class name -> peewee exception class. Names absent from
# this table are left for the driver's own exception to propagate.
EXCEPTIONS = dict(
    ConstraintError=IntegrityError,
    DatabaseError=DatabaseError,
    DataError=DataError,
    IntegrityError=IntegrityError,
    InterfaceError=InterfaceError,
    InternalError=InternalError,
    NotSupportedError=NotSupportedError,
    OperationalError=OperationalError,
    ProgrammingError=ProgrammingError,
    TransactionRollbackError=OperationalError)

# Shared wrapper used around driver calls throughout this module.
__exception_wrapper__ = ExceptionWrapper(EXCEPTIONS)
2933
2934
2935# DATABASE INTERFACE AND CONNECTION MANAGEMENT.
2936
2937
# Lightweight records returned by the schema-introspection methods.
IndexMetadata = collections.namedtuple(
    'IndexMetadata', 'name sql columns unique table')
ColumnMetadata = collections.namedtuple(
    'ColumnMetadata', 'name data_type null primary_key table default')
ForeignKeyMetadata = collections.namedtuple(
    'ForeignKeyMetadata', 'column dest_table dest_column table')
ViewMetadata = collections.namedtuple('ViewMetadata', 'name sql')
2948
2949
class _ConnectionState(object):
    """Holds a single connection plus its context/transaction stacks.

    ``conn`` is the driver connection (None when reset), ``closed`` is True
    when no connection is held, and ``ctx`` / ``transactions`` are stacks
    that start empty whenever the connection is replaced.
    """
    def __init__(self, **kwargs):
        super(_ConnectionState, self).__init__(**kwargs)
        self.reset()

    def _replace(self, conn, closed):
        # Common body of reset() and set_connection().
        self.conn = conn
        self.closed = closed
        self.ctx = []
        self.transactions = []

    def reset(self):
        self._replace(None, True)

    def set_connection(self, conn):
        self._replace(conn, False)
2966
2967
class _ConnectionLocal(_ConnectionState, threading.local):
    """Thread-local connection state: each thread sees its own attributes."""


class _NoopLock(object):
    """Lock stand-in used when thread safety is disabled; does nothing."""
    __slots__ = ()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return None
2973
2974
class ConnectionContext(_callable_context_manager):
    """Open ``db`` on entry if it is closed; always close it on exit."""
    __slots__ = ('db',)

    def __init__(self, db):
        self.db = db

    def __enter__(self):
        database = self.db
        if database.is_closed():
            database.connect()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Closes unconditionally, even if the connection was already open
        # before the block was entered.
        self.db.close()
2982
2983
class Database(_callable_context_manager):
    """Base class for database backends.

    Tracks the connection and the stack of open transactions (in
    thread-local storage when ``thread_safe`` is true), builds SQL contexts
    configured by the class-level feature toggles, and provides
    transaction/session helpers. Backends implement ``_connect()`` and
    override the introspection and date helpers that raise
    ``NotImplementedError`` here.
    """
    context_class = Context
    field_types = {}
    operations = {}
    param = '?'
    quote = '""'
    server_version = None

    # Feature toggles. Backend subclasses override these; several are passed
    # into the SQL context by get_context_options().
    commit_select = False
    compound_select_parentheses = CSQ_PARENTHESES_NEVER
    for_update = False
    index_schema_prefix = False
    index_using_precedes_table = False
    limit_max = None
    nulls_ordering = False
    returning_clause = False
    safe_create_index = True
    safe_drop_index = True
    sequences = False
    truncate_table = True

    def __init__(self, database, thread_safe=True, autorollback=False,
                 field_types=None, operations=None, autocommit=None,
                 autoconnect=True, **kwargs):
        # Layer global defaults, backend class overrides and caller overrides;
        # later layers win.
        self._field_types = merge_dict(FIELD, self.field_types)
        self._operations = merge_dict(OP, self.operations)
        if field_types:
            self._field_types.update(field_types)
        if operations:
            self._operations.update(operations)

        self.autoconnect = autoconnect
        self.autorollback = autorollback
        self.thread_safe = thread_safe
        if thread_safe:
            self._state = _ConnectionLocal()
            self._lock = threading.RLock()
        else:
            self._state = _ConnectionState()
            self._lock = _NoopLock()

        if autocommit is not None:
            __deprecated__('Peewee no longer uses the "autocommit" option, as '
                           'the semantics now require it to always be True. '
                           'Because some database-drivers also use the '
                           '"autocommit" parameter, you are receiving a '
                           'warning so you may update your code and remove '
                           'the parameter, as in the future, specifying '
                           'autocommit could impact the behavior of the '
                           'database driver you are using.')

        self.connect_params = {}
        self.init(database, **kwargs)

    def init(self, database, **kwargs):
        """(Re)configure the target database, closing any open connection.

        ``kwargs`` are merged into ``connect_params``. A falsy ``database``
        marks this instance as deferred: it cannot connect until init() is
        called again with a real value.
        """
        if not self.is_closed():
            self.close()
        self.database = database
        self.connect_params.update(kwargs)
        self.deferred = not bool(database)

    def __enter__(self):
        # Connect if needed and run the block inside atomic(); the
        # connection is closed when the outermost such block exits.
        if self.is_closed():
            self.connect()
        ctx = self.atomic()
        self._state.ctx.append(ctx)
        ctx.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        ctx = self._state.ctx.pop()
        try:
            ctx.__exit__(exc_type, exc_val, exc_tb)
        finally:
            if not self._state.ctx:
                self.close()

    def connection_context(self):
        """Return a context manager/decorator that opens then closes."""
        return ConnectionContext(self)

    def _connect(self):
        """Create and return a new driver connection (backend-specific)."""
        raise NotImplementedError

    def connect(self, reuse_if_open=False):
        """Open a connection.

        Returns True when a new connection was opened, or False when one was
        already open and ``reuse_if_open`` is true.

        Raises:
            InterfaceError: the database is deferred (not initialized).
            OperationalError: a connection is already open and
                ``reuse_if_open`` is false.
        """
        with self._lock:
            if self.deferred:
                raise InterfaceError('Error, database must be initialized '
                                     'before opening a connection.')
            if not self._state.closed:
                if reuse_if_open:
                    return False
                raise OperationalError('Connection already opened.')

            self._state.reset()
            with __exception_wrapper__:
                self._state.set_connection(self._connect())
                if self.server_version is None:
                    self._set_server_version(self._state.conn)
                self._initialize_connection(self._state.conn)
        return True

    def _initialize_connection(self, conn):
        """Hook for per-connection setup; no-op by default."""
        pass

    def _set_server_version(self, conn):
        """Hook to detect the server version; the default records 0."""
        self.server_version = 0

    def close(self):
        """Close the connection; return True if one was open.

        Connection state is reset even if closing the driver connection
        raises.

        Raises:
            InterfaceError: the database is deferred (not initialized).
            OperationalError: a transaction is currently open.
        """
        with self._lock:
            # NOTE(review): this message says "opening" although it is raised
            # from close(); left unchanged in case callers match on it.
            if self.deferred:
                raise InterfaceError('Error, database must be initialized '
                                     'before opening a connection.')
            if self.in_transaction():
                raise OperationalError('Attempting to close database while '
                                       'transaction is open.')
            is_open = not self._state.closed
            try:
                if is_open:
                    with __exception_wrapper__:
                        self._close(self._state.conn)
            finally:
                self._state.reset()
            return is_open

    def _close(self, conn):
        conn.close()

    def is_closed(self):
        return self._state.closed

    def is_connection_usable(self):
        return not self._state.closed

    def connection(self):
        """Return the driver connection, connecting first if needed."""
        if self.is_closed():
            self.connect()
        return self._state.conn

    def cursor(self, commit=None):
        """Return a new driver cursor.

        Connects automatically when ``autoconnect`` is set, otherwise raises
        InterfaceError if closed. ``commit`` is unused in this base class.
        """
        if self.is_closed():
            if self.autoconnect:
                self.connect()
            else:
                raise InterfaceError('Error, database connection not opened.')
        return self._state.conn.cursor()

    def execute_sql(self, sql, params=None, commit=SENTINEL):
        """Execute raw ``sql`` with ``params`` and return the cursor.

        When ``commit`` is not given: never commit inside a transaction;
        otherwise commit unless the statement starts with SELECT (SELECTs
        are committed too when ``commit_select`` is set). On failure outside
        a transaction, rolls back first if ``autorollback`` is enabled.
        """
        logger.debug((sql, params))
        if commit is SENTINEL:
            if self.in_transaction():
                commit = False
            elif self.commit_select:
                commit = True
            else:
                commit = not sql[:6].lower().startswith('select')

        with __exception_wrapper__:
            cursor = self.cursor(commit)
            try:
                cursor.execute(sql, params or ())
            except Exception:
                if self.autorollback and not self.in_transaction():
                    self.rollback()
                raise
            else:
                if commit and not self.in_transaction():
                    self.commit()
        return cursor

    def execute(self, query, commit=SENTINEL, **context_options):
        """Render ``query`` to SQL and run it via execute_sql()."""
        ctx = self.get_sql_context(**context_options)
        sql, params = ctx.sql(query).query()
        return self.execute_sql(sql, params, commit=commit)

    def get_context_options(self):
        """Options used to build a SQL context for this backend."""
        return {
            'field_types': self._field_types,
            'operations': self._operations,
            'param': self.param,
            'quote': self.quote,
            'compound_select_parentheses': self.compound_select_parentheses,
            'conflict_statement': self.conflict_statement,
            'conflict_update': self.conflict_update,
            'for_update': self.for_update,
            'index_schema_prefix': self.index_schema_prefix,
            'index_using_precedes_table': self.index_using_precedes_table,
            'limit_max': self.limit_max,
            'nulls_ordering': self.nulls_ordering,
        }

    def get_sql_context(self, **context_options):
        """Build a ``context_class`` instance, applying any overrides."""
        context = self.get_context_options()
        if context_options:
            context.update(context_options)
        return self.context_class(**context)

    def conflict_statement(self, on_conflict, query):
        raise NotImplementedError

    def conflict_update(self, on_conflict, query):
        raise NotImplementedError

    def _build_on_conflict_update(self, on_conflict, query):
        """Build an ``ON CONFLICT ... DO UPDATE SET ...`` clause.

        The conflict target is either the explicit column list (optionally
        with a WHERE) or a named constraint. ``_preserve`` columns are set
        from ``EXCLUDED``; ``_update`` supplies explicit assignments.
        """
        if on_conflict._conflict_target:
            stmt = SQL('ON CONFLICT')
            target = EnclosedNodeList([
                Entity(col) if isinstance(col, basestring) else col
                for col in on_conflict._conflict_target])
            if on_conflict._conflict_where is not None:
                target = NodeList([target, SQL('WHERE'),
                                   on_conflict._conflict_where])
        else:
            stmt = SQL('ON CONFLICT ON CONSTRAINT')
            target = on_conflict._conflict_constraint
            if isinstance(target, basestring):
                target = Entity(target)

        updates = []
        if on_conflict._preserve:
            for column in on_conflict._preserve:
                excluded = NodeList((SQL('EXCLUDED'), ensure_entity(column)),
                                    glue='.')
                expression = NodeList((ensure_entity(column), SQL('='),
                                       excluded))
                updates.append(expression)

        if on_conflict._update:
            for k, v in on_conflict._update.items():
                if not isinstance(v, Node):
                    # Attempt to resolve string field-names to their respective
                    # field object, to apply data-type conversions.
                    if isinstance(k, basestring):
                        k = getattr(query.table, k)
                    if isinstance(k, Field):
                        v = k.to_value(v)
                    else:
                        v = Value(v, unpack=False)
                else:
                    v = QualifiedNames(v)
                updates.append(NodeList((ensure_entity(k), SQL('='), v)))

        parts = [stmt, target, SQL('DO UPDATE SET'), CommaNodeList(updates)]
        # Compare with None (as for _conflict_where above) instead of relying
        # on the truthiness of an expression node.
        if on_conflict._where is not None:
            parts.extend((SQL('WHERE'), QualifiedNames(on_conflict._where)))

        return NodeList(parts)

    def last_insert_id(self, cursor, query_type=None):
        return cursor.lastrowid

    def rows_affected(self, cursor):
        return cursor.rowcount

    def default_values_insert(self, ctx):
        """Emit the clause used for an INSERT that supplies no columns."""
        return ctx.literal('DEFAULT VALUES')

    def session_start(self):
        """Begin a transaction outside a ``with`` block."""
        with self._lock:
            return self.transaction().__enter__()

    def session_commit(self):
        """Commit the innermost transaction; False if none is open."""
        with self._lock:
            try:
                txn = self.pop_transaction()
            except IndexError:
                return False
            txn.commit(begin=self.in_transaction())
            return True

    def session_rollback(self):
        """Roll back the innermost transaction; False if none is open."""
        with self._lock:
            try:
                txn = self.pop_transaction()
            except IndexError:
                return False
            txn.rollback(begin=self.in_transaction())
            return True

    def in_transaction(self):
        return bool(self._state.transactions)

    def push_transaction(self, transaction):
        self._state.transactions.append(transaction)

    def pop_transaction(self):
        # Raises IndexError when no transaction is open.
        return self._state.transactions.pop()

    def transaction_depth(self):
        return len(self._state.transactions)

    def top_transaction(self):
        """Innermost open transaction, or None."""
        if self._state.transactions:
            return self._state.transactions[-1]

    def atomic(self, *args, **kwargs):
        return _atomic(self, *args, **kwargs)

    def manual_commit(self):
        return _manual(self)

    def transaction(self, *args, **kwargs):
        return _transaction(self, *args, **kwargs)

    def savepoint(self):
        return _savepoint(self)

    def begin(self):
        # The base class only ensures a connection; backends may issue BEGIN.
        if self.is_closed():
            self.connect()

    def commit(self):
        with __exception_wrapper__:
            return self._state.conn.commit()

    def rollback(self):
        with __exception_wrapper__:
            return self._state.conn.rollback()

    def batch_commit(self, it, n):
        """Yield items from ``it``, wrapping each group of ``n`` in atomic()."""
        for group in chunked(it, n):
            with self.atomic():
                for obj in group:
                    yield obj

    def table_exists(self, table_name, schema=None):
        return table_name in self.get_tables(schema=schema)

    # Introspection; implemented by backend subclasses.
    def get_tables(self, schema=None):
        raise NotImplementedError

    def get_indexes(self, table, schema=None):
        raise NotImplementedError

    def get_columns(self, table, schema=None):
        raise NotImplementedError

    def get_primary_keys(self, table, schema=None):
        raise NotImplementedError

    def get_foreign_keys(self, table, schema=None):
        raise NotImplementedError

    def sequence_exists(self, seq):
        raise NotImplementedError

    def create_tables(self, models, **options):
        # Ordered by sort_models (presumably dependencies first).
        for model in sort_models(models):
            model.create_table(**options)

    def drop_tables(self, models, **kwargs):
        # Reverse of the creation order.
        for model in reversed(sort_models(models)):
            model.drop_table(**kwargs)

    # Date/time SQL helpers; implemented by backend subclasses.
    def extract_date(self, date_part, date_field):
        raise NotImplementedError

    def truncate_date(self, date_part, date_field):
        raise NotImplementedError

    def to_timestamp(self, date_field):
        raise NotImplementedError

    def from_timestamp(self, date_field):
        raise NotImplementedError

    def random(self):
        return fn.random()

    def bind(self, models, bind_refs=True, bind_backrefs=True):
        for model in models:
            model.bind(self, bind_refs=bind_refs, bind_backrefs=bind_backrefs)

    def bind_ctx(self, models, bind_refs=True, bind_backrefs=True):
        return _BoundModelsContext(models, self, bind_refs, bind_backrefs)

    def get_noop_select(self, ctx):
        """Emit a SELECT that yields no rows (``SELECT 0 WHERE 0``)."""
        return ctx.sql(Select().columns(SQL('0')).where(SQL('0')))
3362
3363
def __pragma__(name):
    """Build a read/write property backed by ``self.pragma(name[, value])``."""
    def getter(self):
        return self.pragma(name)

    def setter(self, value):
        return self.pragma(name, value)

    return property(fget=getter, fset=setter)
3370
3371
3372class SqliteDatabase(Database):
3373    field_types = {
3374        'BIGAUTO': FIELD.AUTO,
3375        'BIGINT': FIELD.INT,
3376        'BOOL': FIELD.INT,
3377        'DOUBLE': FIELD.FLOAT,
3378        'SMALLINT': FIELD.INT,
3379        'UUID': FIELD.TEXT}
3380    operations = {
3381        'LIKE': 'GLOB',
3382        'ILIKE': 'LIKE'}
3383    index_schema_prefix = True
3384    limit_max = -1
3385    server_version = __sqlite_version__
3386    truncate_table = False
3387
3388    def __init__(self, database, *args, **kwargs):
3389        self._pragmas = kwargs.pop('pragmas', ())
3390        super(SqliteDatabase, self).__init__(database, *args, **kwargs)
3391        self._aggregates = {}
3392        self._collations = {}
3393        self._functions = {}
3394        self._window_functions = {}
3395        self._table_functions = []
3396        self._extensions = set()
3397        self._attached = {}
3398        self.register_function(_sqlite_date_part, 'date_part', 2)
3399        self.register_function(_sqlite_date_trunc, 'date_trunc', 2)
3400        self.nulls_ordering = self.server_version >= (3, 30, 0)
3401
3402    def init(self, database, pragmas=None, timeout=5, **kwargs):
3403        if pragmas is not None:
3404            self._pragmas = pragmas
3405        if isinstance(self._pragmas, dict):
3406            self._pragmas = list(self._pragmas.items())
3407        self._timeout = timeout
3408        super(SqliteDatabase, self).init(database, **kwargs)
3409
3410    def _set_server_version(self, conn):
3411        pass
3412
3413    def _connect(self):
3414        if sqlite3 is None:
3415            raise ImproperlyConfigured('SQLite driver not installed!')
3416        conn = sqlite3.connect(self.database, timeout=self._timeout,
3417                               isolation_level=None, **self.connect_params)
3418        try:
3419            self._add_conn_hooks(conn)
3420        except:
3421            conn.close()
3422            raise
3423        return conn
3424
3425    def _add_conn_hooks(self, conn):
3426        if self._attached:
3427            self._attach_databases(conn)
3428        if self._pragmas:
3429            self._set_pragmas(conn)
3430        self._load_aggregates(conn)
3431        self._load_collations(conn)
3432        self._load_functions(conn)
3433        if self.server_version >= (3, 25, 0):
3434            self._load_window_functions(conn)
3435        if self._table_functions:
3436            for table_function in self._table_functions:
3437                table_function.register(conn)
3438        if self._extensions:
3439            self._load_extensions(conn)
3440
3441    def _set_pragmas(self, conn):
3442        cursor = conn.cursor()
3443        for pragma, value in self._pragmas:
3444            cursor.execute('PRAGMA %s = %s;' % (pragma, value))
3445        cursor.close()
3446
3447    def _attach_databases(self, conn):
3448        cursor = conn.cursor()
3449        for name, db in self._attached.items():
3450            cursor.execute('ATTACH DATABASE "%s" AS "%s"' % (db, name))
3451        cursor.close()
3452
3453    def pragma(self, key, value=SENTINEL, permanent=False, schema=None):
3454        if schema is not None:
3455            key = '"%s".%s' % (schema, key)
3456        sql = 'PRAGMA %s' % key
3457        if value is not SENTINEL:
3458            sql += ' = %s' % (value or 0)
3459            if permanent:
3460                pragmas = dict(self._pragmas or ())
3461                pragmas[key] = value
3462                self._pragmas = list(pragmas.items())
3463        elif permanent:
3464            raise ValueError('Cannot specify a permanent pragma without value')
3465        row = self.execute_sql(sql).fetchone()
3466        if row:
3467            return row[0]
3468
    # Convenience properties: reading runs ``PRAGMA <name>`` and setting runs
    # ``PRAGMA <name> = <value>`` via self.pragma() on the current connection.
    cache_size = __pragma__('cache_size')
    foreign_keys = __pragma__('foreign_keys')
    journal_mode = __pragma__('journal_mode')
    journal_size_limit = __pragma__('journal_size_limit')
    mmap_size = __pragma__('mmap_size')
    page_size = __pragma__('page_size')
    read_uncommitted = __pragma__('read_uncommitted')
    synchronous = __pragma__('synchronous')
    wal_autocheckpoint = __pragma__('wal_autocheckpoint')
3478
3479    @property
3480    def timeout(self):
3481        return self._timeout
3482
3483    @timeout.setter
3484    def timeout(self, seconds):
3485        if self._timeout == seconds:
3486            return
3487
3488        self._timeout = seconds
3489        if not self.is_closed():
3490            # PySQLite multiplies user timeout by 1000, but the unit of the
3491            # timeout PRAGMA is actually milliseconds.
3492            self.execute_sql('PRAGMA busy_timeout=%d;' % (seconds * 1000))
3493
3494    def _load_aggregates(self, conn):
3495        for name, (klass, num_params) in self._aggregates.items():
3496            conn.create_aggregate(name, num_params, klass)
3497
3498    def _load_collations(self, conn):
3499        for name, fn in self._collations.items():
3500            conn.create_collation(name, fn)
3501
3502    def _load_functions(self, conn):
3503        for name, (fn, num_params) in self._functions.items():
3504            conn.create_function(name, num_params, fn)
3505
3506    def _load_window_functions(self, conn):
3507        for name, (klass, num_params) in self._window_functions.items():
3508            conn.create_window_function(name, num_params, klass)
3509
3510    def register_aggregate(self, klass, name=None, num_params=-1):
3511        self._aggregates[name or klass.__name__.lower()] = (klass, num_params)
3512        if not self.is_closed():
3513            self._load_aggregates(self.connection())
3514
3515    def aggregate(self, name=None, num_params=-1):
3516        def decorator(klass):
3517            self.register_aggregate(klass, name, num_params)
3518            return klass
3519        return decorator
3520
3521    def register_collation(self, fn, name=None):
3522        name = name or fn.__name__
3523        def _collation(*args):
3524            expressions = args + (SQL('collate %s' % name),)
3525            return NodeList(expressions)
3526        fn.collation = _collation
3527        self._collations[name] = fn
3528        if not self.is_closed():
3529            self._load_collations(self.connection())
3530
3531    def collation(self, name=None):
3532        def decorator(fn):
3533            self.register_collation(fn, name)
3534            return fn
3535        return decorator
3536
3537    def register_function(self, fn, name=None, num_params=-1):
3538        self._functions[name or fn.__name__] = (fn, num_params)
3539        if not self.is_closed():
3540            self._load_functions(self.connection())
3541
3542    def func(self, name=None, num_params=-1):
3543        def decorator(fn):
3544            self.register_function(fn, name, num_params)
3545            return fn
3546        return decorator
3547
3548    def register_window_function(self, klass, name=None, num_params=-1):
3549        name = name or klass.__name__.lower()
3550        self._window_functions[name] = (klass, num_params)
3551        if not self.is_closed():
3552            self._load_window_functions(self.connection())
3553
3554    def window_function(self, name=None, num_params=-1):
3555        def decorator(klass):
3556            self.register_window_function(klass, name, num_params)
3557            return klass
3558        return decorator
3559
3560    def register_table_function(self, klass, name=None):
3561        if name is not None:
3562            klass.name = name
3563        self._table_functions.append(klass)
3564        if not self.is_closed():
3565            klass.register(self.connection())
3566
3567    def table_function(self, name=None):
3568        def decorator(klass):
3569            self.register_table_function(klass, name)
3570            return klass
3571        return decorator
3572
3573    def unregister_aggregate(self, name):
3574        del(self._aggregates[name])
3575
3576    def unregister_collation(self, name):
3577        del(self._collations[name])
3578
3579    def unregister_function(self, name):
3580        del(self._functions[name])
3581
3582    def unregister_window_function(self, name):
3583        del(self._window_functions[name])
3584
3585    def unregister_table_function(self, name):
3586        for idx, klass in enumerate(self._table_functions):
3587            if klass.name == name:
3588                break
3589        else:
3590            return False
3591        self._table_functions.pop(idx)
3592        return True
3593
3594    def _load_extensions(self, conn):
3595        conn.enable_load_extension(True)
3596        for extension in self._extensions:
3597            conn.load_extension(extension)
3598
3599    def load_extension(self, extension):
3600        self._extensions.add(extension)
3601        if not self.is_closed():
3602            conn = self.connection()
3603            conn.enable_load_extension(True)
3604            conn.load_extension(extension)
3605
3606    def unload_extension(self, extension):
3607        self._extensions.remove(extension)
3608
3609    def attach(self, filename, name):
3610        if name in self._attached:
3611            if self._attached[name] == filename:
3612                return False
3613            raise OperationalError('schema "%s" already attached.' % name)
3614
3615        self._attached[name] = filename
3616        if not self.is_closed():
3617            self.execute_sql('ATTACH DATABASE "%s" AS "%s"' % (filename, name))
3618        return True
3619
3620    def detach(self, name):
3621        if name not in self._attached:
3622            return False
3623
3624        del self._attached[name]
3625        if not self.is_closed():
3626            self.execute_sql('DETACH DATABASE "%s"' % name)
3627        return True
3628
3629    def begin(self, lock_type=None):
3630        statement = 'BEGIN %s' % lock_type if lock_type else 'BEGIN'
3631        self.execute_sql(statement, commit=False)
3632
3633    def get_tables(self, schema=None):
3634        schema = schema or 'main'
3635        cursor = self.execute_sql('SELECT name FROM "%s".sqlite_master WHERE '
3636                                  'type=? ORDER BY name' % schema, ('table',))
3637        return [row for row, in cursor.fetchall()]
3638
3639    def get_views(self, schema=None):
3640        sql = ('SELECT name, sql FROM "%s".sqlite_master WHERE type=? '
3641               'ORDER BY name') % (schema or 'main')
3642        return [ViewMetadata(*row) for row in self.execute_sql(sql, ('view',))]
3643
3644    def get_indexes(self, table, schema=None):
3645        schema = schema or 'main'
3646        query = ('SELECT name, sql FROM "%s".sqlite_master '
3647                 'WHERE tbl_name = ? AND type = ? ORDER BY name') % schema
3648        cursor = self.execute_sql(query, (table, 'index'))
3649        index_to_sql = dict(cursor.fetchall())
3650
3651        # Determine which indexes have a unique constraint.
3652        unique_indexes = set()
3653        cursor = self.execute_sql('PRAGMA "%s".index_list("%s")' %
3654                                  (schema, table))
3655        for row in cursor.fetchall():
3656            name = row[1]
3657            is_unique = int(row[2]) == 1
3658            if is_unique:
3659                unique_indexes.add(name)
3660
3661        # Retrieve the indexed columns.
3662        index_columns = {}
3663        for index_name in sorted(index_to_sql):
3664            cursor = self.execute_sql('PRAGMA "%s".index_info("%s")' %
3665                                      (schema, index_name))
3666            index_columns[index_name] = [row[2] for row in cursor.fetchall()]
3667
3668        return [
3669            IndexMetadata(
3670                name,
3671                index_to_sql[name],
3672                index_columns[name],
3673                name in unique_indexes,
3674                table)
3675            for name in sorted(index_to_sql)]
3676
3677    def get_columns(self, table, schema=None):
3678        cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' %
3679                                  (schema or 'main', table))
3680        return [ColumnMetadata(r[1], r[2], not r[3], bool(r[5]), table, r[4])
3681                for r in cursor.fetchall()]
3682
3683    def get_primary_keys(self, table, schema=None):
3684        cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' %
3685                                  (schema or 'main', table))
3686        return [row[1] for row in filter(lambda r: r[-1], cursor.fetchall())]
3687
3688    def get_foreign_keys(self, table, schema=None):
3689        cursor = self.execute_sql('PRAGMA "%s".foreign_key_list("%s")' %
3690                                  (schema or 'main', table))
3691        return [ForeignKeyMetadata(row[3], row[2], row[4], table)
3692                for row in cursor.fetchall()]
3693
3694    def get_binary_type(self):
3695        return sqlite3.Binary
3696
3697    def conflict_statement(self, on_conflict, query):
3698        action = on_conflict._action.lower() if on_conflict._action else ''
3699        if action and action not in ('nothing', 'update'):
3700            return SQL('INSERT OR %s' % on_conflict._action.upper())
3701
3702    def conflict_update(self, oc, query):
3703        # Sqlite prior to 3.24.0 does not support Postgres-style upsert.
3704        if self.server_version < (3, 24, 0) and \
3705           any((oc._preserve, oc._update, oc._where, oc._conflict_target,
3706                oc._conflict_constraint)):
3707            raise ValueError('SQLite does not support specifying which values '
3708                             'to preserve or update.')
3709
3710        action = oc._action.lower() if oc._action else ''
3711        if action and action not in ('nothing', 'update', ''):
3712            return
3713
3714        if action == 'nothing':
3715            return SQL('ON CONFLICT DO NOTHING')
3716        elif not oc._update and not oc._preserve:
3717            raise ValueError('If you are not performing any updates (or '
3718                             'preserving any INSERTed values), then the '
3719                             'conflict resolution action should be set to '
3720                             '"NOTHING".')
3721        elif oc._conflict_constraint:
3722            raise ValueError('SQLite does not support specifying named '
3723                             'constraints for conflict resolution.')
3724        elif not oc._conflict_target:
3725            raise ValueError('SQLite requires that a conflict target be '
3726                             'specified when doing an upsert.')
3727
3728        return self._build_on_conflict_update(oc, query)
3729
3730    def extract_date(self, date_part, date_field):
3731        return fn.date_part(date_part, date_field, python_value=int)
3732
3733    def truncate_date(self, date_part, date_field):
3734        return fn.date_trunc(date_part, date_field,
3735                             python_value=simple_date_time)
3736
3737    def to_timestamp(self, date_field):
3738        return fn.strftime('%s', date_field).cast('integer')
3739
3740    def from_timestamp(self, date_field):
3741        return fn.datetime(date_field, 'unixepoch')
3742
3743
class PostgresqlDatabase(Database):
    """Database implementation for PostgreSQL, using the psycopg2 driver.

    The class attributes override the generic ``Database`` defaults: column
    types, operators, the parameter placeholder and feature flags.
    """
    # Maps peewee's generic field-type names to Postgres column types.
    field_types = {
        'AUTO': 'SERIAL',
        'BIGAUTO': 'BIGSERIAL',
        'BLOB': 'BYTEA',
        'BOOL': 'BOOLEAN',
        'DATETIME': 'TIMESTAMP',
        'DECIMAL': 'NUMERIC',
        'DOUBLE': 'DOUBLE PRECISION',
        'UUID': 'UUID',
        'UUIDB': 'BYTEA'}
    # POSIX regex operators: "~" is case-sensitive, "~*" case-insensitive.
    operations = {'REGEXP': '~', 'IREGEXP': '~*'}
    # psycopg2 uses the "format" parameter style.
    param = '%s'

    # Feature flags read by the base Database / query compiler (their exact
    # semantics are defined outside this chunk).
    commit_select = True
    compound_select_parentheses = CSQ_PARENTHESES_ALWAYS
    for_update = True
    nulls_ordering = True
    returning_clause = True
    # Switched on per-connection in _set_server_version() for servers >= 9.6.
    safe_create_index = False
    sequences = True

    def init(self, database, register_unicode=True, encoding=None,
             isolation_level=None, **kwargs):
        """Store Postgres-specific options, then defer to the parent ``init``.

        :param register_unicode: register psycopg2's UNICODE and UNICODEARRAY
            typecasters on each new connection.
        :param encoding: client encoding set on each new connection.
        :param isolation_level: passed to ``conn.set_isolation_level`` on each
            new connection when truthy.
        Remaining keyword arguments are forwarded to the parent ``init``.
        """
        self._register_unicode = register_unicode
        self._encoding = encoding
        self._isolation_level = isolation_level
        super(PostgresqlDatabase, self).init(database, **kwargs)

    def _connect(self):
        """Open and configure a new psycopg2 connection.

        Raises ``ImproperlyConfigured`` when psycopg2 could not be imported.
        """
        if psycopg2 is None:
            raise ImproperlyConfigured('Postgres driver not installed!')
        conn = psycopg2.connect(database=self.database, **self.connect_params)
        if self._register_unicode:
            pg_extensions.register_type(pg_extensions.UNICODE, conn)
            pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn)
        if self._encoding:
            conn.set_client_encoding(self._encoding)
        if self._isolation_level:
            conn.set_isolation_level(self._isolation_level)
        return conn

    def _set_server_version(self, conn):
        # psycopg2 reports the server version as an integer (90600 == 9.6.0).
        self.server_version = conn.server_version
        if self.server_version >= 90600:
            self.safe_create_index = True

    def is_connection_usable(self):
        """Return False when closed or when the transaction status is in error."""
        if self._state.closed:
            return False

        # Returns True if we are idle, running a command, or in an active
        # connection. If the connection is in an error state or the connection
        # is otherwise unusable, return False.
        txn_status = self._state.conn.get_transaction_status()
        return txn_status < pg_extensions.TRANSACTION_STATUS_INERROR

    def last_insert_id(self, cursor, query_type=None):
        """Return the inserted id(s) for an INSERT.

        For ``Insert.SIMPLE`` queries, ``cursor`` is indexed as rows and the
        first column of the first row is returned (presumably the RETURNING
        result — confirm).  Any other query type gets ``cursor`` back
        unchanged.  Indexing failures yield ``None``.
        """
        try:
            return cursor if query_type != Insert.SIMPLE else cursor[0][0]
        except (IndexError, KeyError, TypeError):
            pass

    def get_tables(self, schema=None):
        """Return sorted table names in ``schema`` (default ``public``)."""
        query = ('SELECT tablename FROM pg_catalog.pg_tables '
                 'WHERE schemaname = %s ORDER BY tablename')
        cursor = self.execute_sql(query, (schema or 'public',))
        return [table for table, in cursor.fetchall()]

    def get_views(self, schema=None):
        """Return ``ViewMetadata`` for views in ``schema`` (default ``public``).

        Spaces, tabs and semicolons are stripped from both ends of each view
        definition.
        """
        query = ('SELECT viewname, definition FROM pg_catalog.pg_views '
                 'WHERE schemaname = %s ORDER BY viewname')
        cursor = self.execute_sql(query, (schema or 'public',))
        return [ViewMetadata(view_name, sql.strip(' \t;'))
                for (view_name, sql) in cursor.fetchall()]

    def get_indexes(self, table, schema=None):
        """Return ``IndexMetadata`` for indexes on ordinary table ``table``.

        The query yields the index name, its CREATE INDEX definition, the
        unique flag and the comma-joined per-column index expressions.
        Unique indexes are listed first.
        NOTE(review): ``columns.split(',')`` will mis-split expression
        indexes whose expressions themselves contain commas.
        """
        query = """
            SELECT
                i.relname, idxs.indexdef, idx.indisunique,
                array_to_string(ARRAY(
                    SELECT pg_get_indexdef(idx.indexrelid, k + 1, TRUE)
                    FROM generate_subscripts(idx.indkey, 1) AS k
                    ORDER BY k), ',')
            FROM pg_catalog.pg_class AS t
            INNER JOIN pg_catalog.pg_index AS idx ON t.oid = idx.indrelid
            INNER JOIN pg_catalog.pg_class AS i ON idx.indexrelid = i.oid
            INNER JOIN pg_catalog.pg_indexes AS idxs ON
                (idxs.tablename = t.relname AND idxs.indexname = i.relname)
            WHERE t.relname = %s AND t.relkind = %s AND idxs.schemaname = %s
            ORDER BY idx.indisunique DESC, i.relname;"""
        cursor = self.execute_sql(query, (table, 'r', schema or 'public'))
        return [IndexMetadata(name, sql.rstrip(' ;'), columns.split(','),
                              is_unique, table)
                for name, sql, is_unique, columns in cursor.fetchall()]

    def get_columns(self, table, schema=None):
        """Return ``ColumnMetadata`` for ``table``'s columns in ordinal order.

        Nullability comes from ``is_nullable == 'YES'``; primary-key
        membership comes from ``get_primary_keys``.
        """
        query = """
            SELECT column_name, is_nullable, data_type, column_default
            FROM information_schema.columns
            WHERE table_name = %s AND table_schema = %s
            ORDER BY ordinal_position"""
        cursor = self.execute_sql(query, (table, schema or 'public'))
        pks = set(self.get_primary_keys(table, schema))
        return [ColumnMetadata(name, dt, null == 'YES', name in pks, table, df)
                for name, null, dt, df in cursor.fetchall()]

    def get_primary_keys(self, table, schema=None):
        """Return names of columns in ``table``'s PRIMARY KEY constraint.

        NOTE(review): the query has no ORDER BY, so the order of composite
        key columns is not guaranteed.
        """
        query = """
            SELECT kc.column_name
            FROM information_schema.table_constraints AS tc
            INNER JOIN information_schema.key_column_usage AS kc ON (
                tc.table_name = kc.table_name AND
                tc.table_schema = kc.table_schema AND
                tc.constraint_name = kc.constraint_name)
            WHERE
                tc.constraint_type = %s AND
                tc.table_name = %s AND
                tc.table_schema = %s"""
        ctype = 'PRIMARY KEY'
        cursor = self.execute_sql(query, (ctype, table, schema or 'public'))
        return [pk for pk, in cursor.fetchall()]

    def get_foreign_keys(self, table, schema=None):
        """Return ``ForeignKeyMetadata(column, dest_table, dest_column, table)``.

        One entry per distinct (local column, referenced table, referenced
        column) of ``table``'s FOREIGN KEY constraints.
        """
        sql = """
            SELECT DISTINCT
                kcu.column_name, ccu.table_name, ccu.column_name
            FROM information_schema.table_constraints AS tc
            JOIN information_schema.key_column_usage AS kcu
                ON (tc.constraint_name = kcu.constraint_name AND
                    tc.constraint_schema = kcu.constraint_schema AND
                    tc.table_name = kcu.table_name AND
                    tc.table_schema = kcu.table_schema)
            JOIN information_schema.constraint_column_usage AS ccu
                ON (ccu.constraint_name = tc.constraint_name AND
                    ccu.constraint_schema = tc.constraint_schema)
            WHERE
                tc.constraint_type = 'FOREIGN KEY' AND
                tc.table_name = %s AND
                tc.table_schema = %s"""
        cursor = self.execute_sql(sql, (table, schema or 'public'))
        return [ForeignKeyMetadata(row[0], row[1], row[2], table)
                for row in cursor.fetchall()]

    def sequence_exists(self, sequence):
        """Return True if a sequence (relkind 'S') named ``sequence`` exists.

        NOTE(review): pg_namespace is joined but never filtered, so a
        sequence with this name in *any* schema counts.
        """
        res = self.execute_sql("""
            SELECT COUNT(*) FROM pg_class, pg_namespace
            WHERE relkind='S'
                AND pg_class.relnamespace = pg_namespace.oid
                AND relname=%s""", (sequence,))
        return bool(res.fetchone()[0])

    def get_binary_type(self):
        """Return psycopg2's ``Binary`` adapter."""
        return psycopg2.Binary

    def conflict_statement(self, on_conflict, query):
        # Postgres has no INSERT-prefix form; conflict handling is produced by
        # conflict_update() instead.
        return

    def conflict_update(self, oc, query):
        """Build the ``ON CONFLICT`` clause for ``oc``.

        "ignore"/"nothing" map to ``ON CONFLICT DO NOTHING``.  Raises
        ``ValueError`` for other non-"update" actions, for updates with
        nothing to update or preserve, and when neither a conflict target
        nor a constraint is given.
        """
        action = oc._action.lower() if oc._action else ''
        if action in ('ignore', 'nothing'):
            return SQL('ON CONFLICT DO NOTHING')
        elif action and action != 'update':
            raise ValueError('The only supported actions for conflict '
                             'resolution with Postgresql are "ignore" or '
                             '"update".')
        elif not oc._update and not oc._preserve:
            raise ValueError('If you are not performing any updates (or '
                             'preserving any INSERTed values), then the '
                             'conflict resolution action should be set to '
                             '"IGNORE".')
        elif not (oc._conflict_target or oc._conflict_constraint):
            raise ValueError('Postgres requires that a conflict target be '
                             'specified when doing an upsert.')

        return self._build_on_conflict_update(oc, query)

    def extract_date(self, date_part, date_field):
        """Return ``EXTRACT(<date_part> FROM <date_field>)``."""
        return fn.EXTRACT(NodeList((date_part, SQL('FROM'), date_field)))

    def truncate_date(self, date_part, date_field):
        """Return ``DATE_TRUNC(<date_part>, <date_field>)``."""
        return fn.DATE_TRUNC(date_part, date_field)

    def to_timestamp(self, date_field):
        """Return the Unix epoch of ``date_field`` via ``EXTRACT(EPOCH ...)``."""
        return self.extract_date('EPOCH', date_field)

    def from_timestamp(self, date_field):
        # Ironically, here, Postgres means "to the Postgresql timestamp type".
        return fn.to_timestamp(date_field)

    def get_noop_select(self, ctx):
        # SELECT 0 WHERE false: a valid statement that returns no rows.
        return ctx.sql(Select().columns(SQL('0')).where(SQL('false')))

    def set_time_zone(self, timezone):
        # NOTE(review): ``timezone`` is interpolated unescaped inside double
        # quotes; never pass untrusted input here.
        self.execute_sql('set time zone "%s";' % timezone)
3939
3940
class MySQLDatabase(Database):
    """Database implementation for MySQL / MariaDB.

    Uses whichever driver the module-level import bound to ``mysql``
    (pymysql, falling back to MySQLdb).  The class attributes override the
    generic ``Database`` defaults.
    """
    # Maps peewee's generic field-type names to MySQL column types.
    field_types = {
        'AUTO': 'INTEGER AUTO_INCREMENT',
        'BIGAUTO': 'BIGINT AUTO_INCREMENT',
        'BOOL': 'BOOL',
        'DECIMAL': 'NUMERIC',
        'DOUBLE': 'DOUBLE PRECISION',
        'FLOAT': 'FLOAT',
        'UUID': 'VARCHAR(40)',
        'UUIDB': 'VARBINARY(16)'}
    # "BINARY" forces a byte-wise (case-sensitive) match; the plain operators
    # serve as the case-insensitive variants.
    operations = {
        'LIKE': 'LIKE BINARY',
        'ILIKE': 'LIKE',
        'REGEXP': 'REGEXP BINARY',
        'IREGEXP': 'REGEXP',
        'XOR': 'XOR'}
    param = '%s'
    # Identifiers are wrapped in backticks.
    quote = '``'

    # Feature flags read by the base Database / query compiler.
    commit_select = True
    compound_select_parentheses = CSQ_PARENTHESES_UNNESTED
    for_update = True
    index_using_precedes_table = True
    # Largest row count MySQL accepts in LIMIT (unsigned 64-bit).
    limit_max = 2 ** 64 - 1
    safe_create_index = False
    safe_drop_index = False
    # PIPES_AS_CONCAT makes "||" act as string concatenation.
    sql_mode = 'PIPES_AS_CONCAT'

    def init(self, database, **kwargs):
        """Apply default connect params (utf8, sql_mode, unicode), then defer.

        User kwargs override the defaults.  With MySQLdb (``mysql_passwd``),
        ``password`` is renamed to the ``passwd`` keyword that driver expects.
        """
        params = {
            'charset': 'utf8',
            'sql_mode': self.sql_mode,
            'use_unicode': True}
        params.update(kwargs)
        if 'password' in params and mysql_passwd:
            params['passwd'] = params.pop('password')
        super(MySQLDatabase, self).init(database, **params)

    def _connect(self):
        """Open a new driver connection.

        Raises ``ImproperlyConfigured`` if no MySQL driver was importable.
        """
        if mysql is None:
            raise ImproperlyConfigured('MySQL driver not installed!')
        conn = mysql.connect(db=self.database, **self.connect_params)
        return conn

    def _set_server_version(self, conn):
        # pymysql exposes ``server_version``; otherwise fall back to
        # ``get_server_info()``.
        try:
            version_raw = conn.server_version
        except AttributeError:
            version_raw = conn.get_server_info()
        self.server_version = self._extract_server_version(version_raw)

    def _extract_server_version(self, version):
        """Parse a version string into an int tuple, e.g. ``(8, 0, 21)``.

        MariaDB strings are matched against a two-digit major version
        (10.x, 11.x, ...).  Warns and returns ``(0, 0, 0)`` when no version
        can be found.
        """
        version = version.lower()
        if 'maria' in version:
            match_obj = re.search(r'(1\d\.\d+\.\d+)', version)
        else:
            match_obj = re.search(r'(\d\.\d+\.\d+)', version)
        if match_obj is not None:
            return tuple(int(num) for num in match_obj.groups()[0].split('.'))

        warnings.warn('Unable to determine MySQL version: "%s"' % version)
        return (0, 0, 0)  # Unable to determine version!

    def default_values_insert(self, ctx):
        # MySQL spells an all-defaults insert as "() VALUES ()".
        return ctx.literal('() VALUES ()')

    def get_tables(self, schema=None):
        """Return sorted non-view table names of the current database.

        ``schema`` is ignored; the current ``DATABASE()`` is used.
        """
        query = ('SELECT table_name FROM information_schema.tables '
                 'WHERE table_schema = DATABASE() AND table_type != %s '
                 'ORDER BY table_name')
        return [table for table, in self.execute_sql(query, ('VIEW',))]

    def get_views(self, schema=None):
        """Return ``ViewMetadata`` for views in the current database."""
        query = ('SELECT table_name, view_definition '
                 'FROM information_schema.views '
                 'WHERE table_schema = DATABASE() ORDER BY table_name')
        cursor = self.execute_sql(query)
        return [ViewMetadata(*row) for row in cursor.fetchall()]

    def _show_index_rows(self, table):
        # All rows of ``SHOW INDEX`` for ``table``.  Embedded backticks are
        # doubled so the name stays a single, correctly quoted identifier.
        quoted = ('%s' % (table,)).replace('`', '``')
        return self.execute_sql('SHOW INDEX FROM `%s`' % quoted).fetchall()

    def get_indexes(self, table, schema=None):
        """Return ``IndexMetadata`` for each index on ``table``.

        SQL definitions are not available here, so ``None`` is passed for
        them.
        """
        unique = set()
        indexes = {}
        for row in self._show_index_rows(table):
            # SHOW INDEX columns: 1=Non_unique, 2=Key_name, 4=Column_name.
            if not row[1]:
                unique.add(row[2])
            indexes.setdefault(row[2], []).append(row[4])
        return [IndexMetadata(name, None, indexes[name], name in unique, table)
                for name in indexes]

    def get_columns(self, table, schema=None):
        """Return ``ColumnMetadata`` for ``table`` in the current database."""
        sql = """
            SELECT column_name, is_nullable, data_type, column_default
            FROM information_schema.columns
            WHERE table_name = %s AND table_schema = DATABASE()"""
        cursor = self.execute_sql(sql, (table,))
        pks = set(self.get_primary_keys(table))
        return [ColumnMetadata(name, dt, null == 'YES', name in pks, table, df)
                for name, null, dt, df in cursor.fetchall()]

    def get_primary_keys(self, table, schema=None):
        """Return the column names of ``table``'s ``PRIMARY`` index."""
        return [row[4] for row in self._show_index_rows(table)
                if row[2] == 'PRIMARY']

    def get_foreign_keys(self, table, schema=None):
        """Return ``ForeignKeyMetadata`` for ``table``'s referencing columns."""
        query = """
            SELECT column_name, referenced_table_name, referenced_column_name
            FROM information_schema.key_column_usage
            WHERE table_name = %s
                AND table_schema = DATABASE()
                AND referenced_table_name IS NOT NULL
                AND referenced_column_name IS NOT NULL"""
        cursor = self.execute_sql(query, (table,))
        return [
            ForeignKeyMetadata(column, dest_table, dest_column, table)
            for column, dest_table, dest_column in cursor.fetchall()]

    def get_binary_type(self):
        """Return the driver's ``Binary`` constructor."""
        return mysql.Binary

    def conflict_statement(self, on_conflict, query):
        """Map the conflict action to a statement prefix.

        "replace" becomes ``REPLACE`` and "ignore" becomes ``INSERT IGNORE``.
        "update" returns ``None``, since it is handled by conflict_update().
        Any other action raises ``ValueError``.
        """
        if not on_conflict._action:
            return

        action = on_conflict._action.lower()
        if action == 'replace':
            return SQL('REPLACE')
        elif action == 'ignore':
            return SQL('INSERT IGNORE')
        elif action != 'update':
            raise ValueError('Un-supported action for conflict resolution. '
                             'MySQL supports REPLACE, IGNORE and UPDATE.')

    def conflict_update(self, on_conflict, query):
        """Build an ``ON DUPLICATE KEY UPDATE`` clause, or return None.

        Raises ``ValueError`` if a WHERE clause, conflict target or conflict
        constraint is given, since MySQL cannot express them.
        """
        if on_conflict._where or on_conflict._conflict_target or \
           on_conflict._conflict_constraint:
            raise ValueError('MySQL does not support the specification of '
                             'where clauses or conflict targets for conflict '
                             'resolution.')

        updates = []
        if on_conflict._preserve:
            # Here we need to determine which function to use, which varies
            # depending on the MySQL server version. MySQL and MariaDB prior to
            # 10.3.3 use "VALUES", while MariaDB 10.3.3+ use "VALUE".
            version = self.server_version or (0,)
            if version[0] == 10 and version >= (10, 3, 3):
                VALUE_FN = fn.VALUE
            else:
                VALUE_FN = fn.VALUES

            for column in on_conflict._preserve:
                entity = ensure_entity(column)
                # "<col> = VALUE(S)(<col>)": keep the value from the INSERT.
                updates.append(NodeList((entity, SQL('='), VALUE_FN(entity))))

        if on_conflict._update:
            for k, v in on_conflict._update.items():
                if not isinstance(v, Node):
                    # Attempt to resolve string field-names to their respective
                    # field object, to apply data-type conversions.
                    if isinstance(k, basestring):
                        k = getattr(query.table, k)
                    if isinstance(k, Field):
                        v = k.to_value(v)
                    else:
                        v = Value(v, unpack=False)
                updates.append(NodeList((ensure_entity(k), SQL('='), v)))

        if updates:
            return NodeList((SQL('ON DUPLICATE KEY UPDATE'),
                             CommaNodeList(updates)))

    def extract_date(self, date_part, date_field):
        """Return ``EXTRACT(<date_part> FROM <date_field>)`` (part as raw SQL)."""
        return fn.EXTRACT(NodeList((SQL(date_part), SQL('FROM'), date_field)))

    def truncate_date(self, date_part, date_field):
        """Truncate via DATE_FORMAT with the format for ``date_part``.

        The format comes from the module-level ``__mysql_date_trunc__``.
        """
        return fn.DATE_FORMAT(date_field, __mysql_date_trunc__[date_part],
                              python_value=simple_date_time)

    def to_timestamp(self, date_field):
        """Return ``UNIX_TIMESTAMP(<date_field>)``."""
        return fn.UNIX_TIMESTAMP(date_field)

    def from_timestamp(self, date_field):
        """Return ``FROM_UNIXTIME(<date_field>)``."""
        return fn.FROM_UNIXTIME(date_field)

    def random(self):
        """Return MySQL's random function, ``RAND()``."""
        return fn.rand()

    def get_noop_select(self, ctx):
        # "DO 0" evaluates an expression and returns no result set.
        return ctx.literal('DO 0')
4136
4137
4138# TRANSACTION CONTROL.
4139
4140
class _manual(_callable_context_manager):
    """Marks a manual-commit block on ``db``'s transaction stack.

    Entering is refused (``ValueError``) when a non-manual transaction is on
    top of the stack.  Exiting raises ``ValueError`` if the popped entry is
    not this object.
    """
    def __init__(self, db):
        self.db = db

    def __enter__(self):
        current = self.db.top_transaction()
        if current is None or isinstance(current, _manual):
            self.db.push_transaction(self)
            return
        raise ValueError('Cannot enter manual commit block while a '
                         'transaction is active.')

    def __exit__(self, exc_type, exc_val, exc_tb):
        popped = self.db.pop_transaction()
        if popped is not self:
            raise ValueError('Transaction stack corrupted while exiting '
                             'manual commit block.')
4156
4157
class _atomic(_callable_context_manager):
    """Context manager that opens a transaction or a nested savepoint.

    At transaction depth 0, ``db.transaction(*args, **kwargs)`` is used.
    Otherwise ``db.savepoint()`` is used, unless a manual-commit block is on
    top of the stack (``ValueError``).  ``__enter__`` returns whatever the
    chosen helper's ``__enter__`` returns.
    """
    def __init__(self, db, *args, **kwargs):
        self.db = db
        self._transaction_args = (args, kwargs)

    def __enter__(self):
        if self.db.transaction_depth() != 0:
            if isinstance(self.db.top_transaction(), _manual):
                raise ValueError('Cannot enter atomic commit block while in '
                                 'manual commit mode.')
            helper = self.db.savepoint()
        else:
            args, kwargs = self._transaction_args
            helper = self.db.transaction(*args, **kwargs)
        self._helper = helper
        return helper.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Delegate commit/rollback decisions to the transaction or savepoint.
        return self._helper.__exit__(exc_type, exc_val, exc_tb)
4176
4177
class _transaction(_callable_context_manager):
    """Context manager wrapping a database transaction.

    Only the outermost block (depth 0 on entry) issues BEGIN, and only a
    block exiting at depth 1 commits.  Any exception triggers
    ``db.rollback()`` regardless of nesting depth.  The transaction is always
    popped from the stack on exit.
    """
    def __init__(self, db, *args, **kwargs):
        self.db = db
        self._begin_args = (args, kwargs)

    def _begin(self):
        begin_args, begin_kwargs = self._begin_args
        self.db.begin(*begin_args, **begin_kwargs)

    def commit(self, begin=True):
        """Commit; when ``begin`` is true, immediately start a new transaction."""
        self.db.commit()
        if begin:
            self._begin()

    def rollback(self, begin=True):
        """Roll back; when ``begin`` is true, immediately start a new transaction."""
        self.db.rollback()
        if begin:
            self._begin()

    def __enter__(self):
        is_outermost = self.db.transaction_depth() == 0
        if is_outermost:
            self._begin()
        self.db.push_transaction(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type:
                self.rollback(False)
                return
            if self.db.transaction_depth() != 1:
                return
            try:
                self.commit(False)
            except BaseException:
                # A failed COMMIT must not leave the transaction half-open.
                self.rollback(False)
                raise
        finally:
            self.db.pop_transaction()
4215
4216
class _savepoint(_callable_context_manager):
    """Context manager wrapping a SQL SAVEPOINT.

    The savepoint name defaults to ``'s'`` plus a random uuid4 hex string and
    is wrapped in ``db.quote``.  On a clean exit the savepoint is RELEASEd.
    On an exception, or when the release fails, it is rolled back to.
    ``rollback()`` does not release the savepoint.
    """
    def __init__(self, db, sid=None):
        self.db = db
        if not sid:
            sid = 's' + uuid.uuid4().hex
        self.sid = sid
        # ``quote`` holds the opening and closing quote characters.
        self.quoted_sid = sid.join(db.quote)

    def _begin(self):
        self.db.execute_sql('SAVEPOINT %s;' % self.quoted_sid)

    def commit(self, begin=True):
        """Release the savepoint; when ``begin`` is true, open it again."""
        self.db.execute_sql('RELEASE SAVEPOINT %s;' % self.quoted_sid)
        if begin:
            self._begin()

    def rollback(self):
        """Roll back to (but do not release) the savepoint."""
        self.db.execute_sql('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid)

    def __enter__(self):
        self._begin()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not exc_type:
            try:
                self.commit(begin=False)
            except BaseException:
                self.rollback()
                raise
        else:
            self.rollback()
4246
4247
4248# CURSOR REPRESENTATIONS.
4249
4250
class CursorWrapper(object):
    """Lazily consume a DB-API cursor, caching processed rows.

    Rows are pulled one at a time with ``cursor.fetchone()``.
    ``initialize()`` runs once before the first row, and ``process_row()``
    converts each raw row (identity here; subclasses override it).
    Processed rows are appended to ``row_cache`` so the result set can be
    iterated and indexed repeatedly.  ``count`` is the number of rows
    fetched so far, and ``populated`` becomes True once the cursor is
    exhausted (the cursor is then closed).
    """
    def __init__(self, cursor):
        self.cursor = cursor
        self.count = 0
        self.index = 0
        self.initialized = False
        self.populated = False
        self.row_cache = []

    def __iter__(self):
        if self.populated:
            return iter(self.row_cache)
        return ResultIterator(self)

    def __getitem__(self, item):
        """Return the row at an int index, or a list of rows for a slice.

        Only as many rows as needed are fetched.  Negative indices, negative
        starts or steps, and open-ended slices need the full result set.
        Raises ``IndexError`` for an out-of-range int and ``ValueError`` for
        any other index type.
        """
        if isinstance(item, slice):
            start, stop, step = item.start, item.stop, item.step
            # A prefix of the result suffices only when the slice is bounded
            # and walks forward from a non-negative start.
            bounded = (stop is not None and stop >= 0 and
                       (start is None or start >= 0) and
                       (step is None or step > 0))
            if not bounded:
                self.fill_cache()
            elif stop > 0:
                self.fill_cache(stop)
            return self.row_cache[item]
        elif isinstance(item, int):
            # Row ``item`` is the (item + 1)-th row, so that many must be
            # cached.  Negative indices are relative to the complete result.
            self.fill_cache(item + 1 if item >= 0 else 0)
            return self.row_cache[item]
        else:
            raise ValueError('CursorWrapper only supports integer and slice '
                             'indexes.')

    def __len__(self):
        self.fill_cache()
        return self.count

    def initialize(self):
        """Hook run once, just before the first row is processed."""
        pass

    def iterate(self, cache=True):
        """Fetch, process and (optionally) cache the next row.

        Raises ``StopIteration`` (after closing the cursor) when exhausted.
        """
        row = self.cursor.fetchone()
        if row is None:
            self.populated = True
            self.cursor.close()
            raise StopIteration
        elif not self.initialized:
            self.initialize()  # Lazy initialization.
            self.initialized = True
        self.count += 1
        result = self.process_row(row)
        if cache:
            self.row_cache.append(result)
        return result

    def process_row(self, row):
        """Convert a raw cursor row; the base class returns it unchanged."""
        return row

    def iterator(self):
        """Efficient one-pass iteration over the result set."""
        while True:
            try:
                yield self.iterate(False)
            except StopIteration:
                return

    def fill_cache(self, n=0):
        """Fetch rows into ``row_cache`` until ``n`` rows have been read.

        ``n == 0`` (the default) means "read everything".  Raises
        ``ValueError`` for negative ``n``.
        """
        n = n or float('Inf')
        if n < 0:
            raise ValueError('Negative values are not supported.')

        while not self.populated and n > self.count:
            try:
                self.iterate()
            except StopIteration:
                break
4325
4326
class DictCursorWrapper(CursorWrapper):
    """Cursor wrapper that yields each row as a ``dict`` keyed by column name."""
    def _initialize_columns(self):
        # Column names come from cursor.description.  Anything up to and
        # including the first '.' is dropped, then '"' and ')' characters are
        # stripped from both ends.
        description = self.cursor.description
        names = []
        for col_desc in description:
            raw_name = col_desc[0]
            names.append(raw_name[raw_name.find('.') + 1:].strip('")'))
        self.columns = names
        self.ncols = len(description)

    initialize = _initialize_columns

    def _row_to_dict(self, row):
        # When two columns share a name, the first value wins.
        column_names = self.columns
        mapping = {}
        for position in range(self.ncols):
            mapping.setdefault(column_names[position], row[position])
        return mapping

    process_row = _row_to_dict
4343
4344
class NamedTupleCursorWrapper(CursorWrapper):
    """Cursor wrapper that yields each row as a ``Row`` namedtuple.

    Field names come from cursor.description, with any prefix up to the first
    '.' removed and '"' stripped.  ``collections.namedtuple`` raises
    ``ValueError`` for names that are not valid or are duplicated.
    """
    def initialize(self):
        field_names = []
        for col_desc in self.cursor.description:
            raw_name = col_desc[0]
            field_names.append(raw_name[raw_name.find('.') + 1:].strip('"'))
        self.tuple_class = collections.namedtuple('Row', field_names)

    def process_row(self, row):
        make_row = self.tuple_class
        return make_row(*row)
4354
4355
class ObjectCursorWrapper(DictCursorWrapper):
    """Cursor wrapper that builds each row as ``constructor(**row_as_dict)``."""
    def __init__(self, cursor, constructor):
        super(ObjectCursorWrapper, self).__init__(cursor)
        self.constructor = constructor

    def process_row(self, row):
        keyword_args = self._row_to_dict(row)
        build = self.constructor
        return build(**keyword_args)
4364
4365
class ResultIterator(object):
    """Iterator over a cursor wrapper's rows.

    Rows already in ``row_cache`` are served first.  After that, new rows are
    pulled through ``cursor_wrapper.iterate()``, which caches them.  Several
    iterators may share one wrapper, each with its own position.
    """
    def __init__(self, cursor_wrapper):
        self.cursor_wrapper = cursor_wrapper
        self.index = 0

    def __iter__(self):
        return self

    def next(self):
        wrapper = self.cursor_wrapper
        if self.index >= wrapper.count:
            if wrapper.populated:
                raise StopIteration
            # May itself raise StopIteration when the cursor runs dry.
            wrapper.iterate()
        obj = wrapper.row_cache[self.index]
        self.index += 1
        return obj

    __next__ = next
4386
4387# FIELDS
4388
class FieldAccessor(object):
    """Descriptor that maps a model attribute onto ``instance.__data__[name]``.

    Class-level access returns the field itself.  Instance reads return the
    stored value (``None`` when unset).  Writes store the value and add
    ``name`` to ``instance._dirty``.
    """
    def __init__(self, model, field, name):
        self.model = model
        self.field = field
        self.name = name

    def __get__(self, instance, instance_type=None):
        if instance is None:
            return self.field
        return instance.__data__.get(self.name)

    def __set__(self, instance, value):
        key = self.name
        instance.__data__[key] = value
        instance._dirty.add(key)
4403
4404
class ForeignKeyAccessor(FieldAccessor):
    """Descriptor for foreign-key attributes.

    Reads return the related object from ``instance.__rel__``.  If it is
    missing and ``field.lazy_load`` is set, it is fetched with
    ``rel_model.get(rel_field == value)`` and cached.  Otherwise the raw key
    value is returned.  A missing (``None``) key on a non-null field raises
    ``rel_model.DoesNotExist``.
    """
    def __init__(self, model, field, name):
        super(ForeignKeyAccessor, self).__init__(model, field, name)
        self.rel_model = field.rel_model

    def get_rel_instance(self, instance):
        name = self.name
        value = instance.__data__.get(name)
        related = instance.__rel__
        if value is None and name not in related:
            if not self.field.null:
                raise self.rel_model.DoesNotExist
            return value
        if name not in related and self.field.lazy_load:
            related[name] = self.rel_model.get(self.field.rel_field == value)
        return related.get(name, value)

    def __get__(self, instance, instance_type=None):
        if instance is None:
            return self.field
        return self.get_rel_instance(instance)

    def __set__(self, instance, obj):
        name = self.name
        if isinstance(obj, self.rel_model):
            # Store the referenced key value and cache the object itself.
            instance.__data__[name] = getattr(obj, self.field.rel_field.name)
            instance.__rel__[name] = obj
        else:
            previous = instance.__data__.get(name)
            instance.__data__[name] = obj
            # A different raw key invalidates any cached related object.
            if obj != previous and name in instance.__rel__:
                del instance.__rel__[name]
        instance._dirty.add(name)
4436
4437
class BackrefAccessor(object):
    """Descriptor installed on the referenced model for a foreign key.

    For a foreign key ``field`` declared on model A pointing at model B,
    this lives on B: accessing it on a B instance returns a query selecting
    the A rows whose foreign key equals that instance's ``rel_field`` value.
    Class-level access returns the descriptor itself.
    """
    def __init__(self, field):
        self.field = field
        # Roles are swapped relative to the foreign key: "model" is the
        # referenced model, "rel_model" is the model declaring the key.
        self.model = field.rel_model
        self.rel_model = field.model

    def __get__(self, instance, instance_type=None):
        if instance is None:
            return self
        attr_name = self.field.rel_field.name
        query = self.rel_model.select()
        return query.where(self.field == getattr(instance, attr_name))
4451
4452
class ObjectIdAccessor(object):
    """Descriptor exposing the raw foreign-key id (e.g. ``obj.parent_id``).

    Reads the stored key from ``instance.__data__``; when that is ``None``
    but a related object is cached in ``instance.__rel__``, the id is taken
    from that object's ``rel_field`` attribute instead. Assignment is
    forwarded to the foreign-key attribute itself.
    """
    def __init__(self, field):
        self.field = field

    def __get__(self, instance, instance_type=None):
        if instance is None:
            return self.field
        key = self.field.name
        raw_id = instance.__data__.get(key)
        if raw_id is None and key in instance.__rel__:
            # No stored key yet: derive it from the cached related object.
            raw_id = getattr(instance.__rel__[key], self.field.rel_field.name)
        return raw_id

    def __set__(self, instance, value):
        setattr(instance, self.field.name, value)
4470
4471
class Field(ColumnBase):
    """Base class for model fields.

    Holds per-column metadata (nullability, indexing, default, primary key,
    DDL constraints, collation, ...) and converts values between Python and
    the database through ``adapt()``, ``db_value()`` and ``python_value()``.
    Subclasses typically override ``field_type`` and/or ``adapt()``.
    """
    # Incremented for every Field instantiated; records definition order.
    _field_counter = 0
    _order = 0
    # Descriptor class that bind() installs on the model under the field name.
    accessor_class = FieldAccessor
    auto_increment = False
    # Used as ``index_type`` when none is passed to __init__.
    default_index_type = None
    # Column type name; ddl_datatype() may remap it via ctx.state.field_types.
    field_type = 'DEFAULT'
    unpack = True

    def __init__(self, null=False, index=False, unique=False, column_name=None,
                 default=None, primary_key=False, constraints=None,
                 sequence=None, collation=None, unindexed=False, choices=None,
                 help_text=None, verbose_name=None, index_type=None,
                 db_column=None, _hidden=False):
        # "db_column" is a deprecated alias: when given, it is reported via
        # __deprecated__() and replaces column_name.
        if db_column is not None:
            __deprecated__('"db_column" has been deprecated in favor of '
                           '"column_name" for Field objects.')
            column_name = db_column

        self.null = null
        self.index = index
        self.unique = unique
        self.column_name = column_name  # If None, bind() uses the attribute name.
        self.default = default
        self.primary_key = primary_key
        self.constraints = constraints  # List of column constraints.
        self.sequence = sequence  # Name of sequence, e.g. foo_id_seq.
        self.collation = collation
        self.unindexed = unindexed
        self.choices = choices
        self.help_text = help_text
        self.verbose_name = verbose_name
        self.index_type = index_type or self.default_index_type
        self._hidden = _hidden

        # Used internally for recovering the order in which Fields were defined
        # on the Model class.
        Field._field_counter += 1
        self._order = Field._field_counter
        # Primary keys sort before other fields (1 vs 2); ties are broken by
        # definition order.
        self._sort_key = (self.primary_key and 1 or 2), self._order

    def __hash__(self):
        # Hash on "<name>.<ModelName>"; only valid once bind() has set
        # ``name`` and ``model``.
        return hash(self.name + '.' + self.model.__name__)

    def __repr__(self):
        # Unbound fields (no model/name yet) render as "<Type: (unbound)>".
        if hasattr(self, 'model') and getattr(self, 'name', None):
            return '<%s: %s.%s>' % (type(self).__name__,
                                    self.model.__name__,
                                    self.name)
        return '<%s: (unbound)>' % type(self).__name__

    def bind(self, model, name, set_attribute=True):
        """Attach this field to ``model`` under attribute ``name``.

        ``column_name`` defaults to ``name`` if it was not given. When
        ``set_attribute`` is true, an ``accessor_class`` descriptor is set on
        the model class under ``name``.
        """
        self.model = model
        self.name = self.safe_name = name
        self.column_name = self.column_name or name
        if set_attribute:
            setattr(model, name, self.accessor_class(model, self, name))

    @property
    def column(self):
        """Column expression for this field on the model's table."""
        return Column(self.model._meta.table, self.column_name)

    def adapt(self, value):
        """Coerce a non-None value; identity here, overridden by subclasses."""
        return value

    def db_value(self, value):
        """Python -> database conversion; ``None`` passes through unchanged."""
        return value if value is None else self.adapt(value)

    def python_value(self, value):
        """Database -> Python conversion; ``None`` passes through unchanged."""
        return value if value is None else self.adapt(value)

    def to_value(self, value):
        # Wrap ``value`` in a Value node converted with db_value().
        # NOTE(review): unpack=False presumably stops list/tuple values from
        # being expanded into multiple parameters — confirm against Value.
        return Value(value, self.db_value, unpack=False)

    def get_sort_key(self, ctx):
        # ``ctx`` is unused; ordering is fixed at construction time.
        return self._sort_key

    def __sql__(self, ctx):
        # A field renders as its underlying column.
        return ctx.sql(self.column)

    def get_modifiers(self):
        """Return type modifiers (e.g. ``[length]``) or None for no modifiers.

        Modifiers are rendered by ddl_datatype() as ``TYPE(m1, m2, ...)``.
        """
        pass

    def ddl_datatype(self, ctx):
        """Return the SQL column type for this field.

        ``field_type`` is looked up in ``ctx.state.field_types`` when the
        context provides that mapping (falling back to ``field_type`` itself),
        then any modifiers from get_modifiers() are appended in parentheses.
        """
        if ctx and ctx.state.field_types:
            column_type = ctx.state.field_types.get(self.field_type,
                                                    self.field_type)
        else:
            column_type = self.field_type

        modifiers = self.get_modifiers()
        if column_type and modifiers:
            modifier_literal = ', '.join([str(m) for m in modifiers])
            return SQL('%s(%s)' % (column_type, modifier_literal))
        else:
            return SQL(column_type)

    def ddl(self, ctx):
        """Build the column definition for CREATE TABLE.

        Clauses are emitted in this order: column name, data type (if any),
        UNINDEXED, NOT NULL, PRIMARY KEY, sequence default, extra constraints,
        COLLATE.
        """
        accum = [Entity(self.column_name)]
        data_type = self.ddl_datatype(ctx)
        if data_type:
            accum.append(data_type)
        if self.unindexed:
            accum.append(SQL('UNINDEXED'))
        if not self.null:
            accum.append(SQL('NOT NULL'))
        if self.primary_key:
            accum.append(SQL('PRIMARY KEY'))
        if self.sequence:
            accum.append(SQL("DEFAULT NEXTVAL('%s')" % self.sequence))
        if self.constraints:
            accum.extend(self.constraints)
        if self.collation:
            accum.append(SQL('COLLATE %s' % self.collation))
        return NodeList(accum)
4587
4588
class IntegerField(Field):
    """Integer column; values are coerced with ``int()`` where possible."""
    field_type = 'INT'

    def adapt(self, value):
        # Values that int() rejects with ValueError (e.g. non-numeric
        # strings) are passed through unchanged instead of raising.
        try:
            coerced = int(value)
        except ValueError:
            coerced = value
        return coerced
4597
4598
class BigIntegerField(IntegerField):
    # Same conversion as IntegerField; only the column type differs.
    field_type = 'BIGINT'
4601
4602
class SmallIntegerField(IntegerField):
    # Same conversion as IntegerField; only the column type differs.
    field_type = 'SMALLINT'
4605
4606
class AutoField(IntegerField):
    """Auto-incrementing integer primary key.

    ``primary_key`` is always forced to True; explicitly passing
    ``primary_key=False`` raises ValueError.
    """
    auto_increment = True
    field_type = 'AUTO'

    def __init__(self, *args, **kwargs):
        requested = kwargs.get('primary_key')
        if requested is False:
            raise ValueError('%s must always be a primary key.' % type(self))
        kwargs.update(primary_key=True)
        super(AutoField, self).__init__(*args, **kwargs)
4616
4617
class BigAutoField(AutoField):
    # Auto-incrementing primary key using the 64-bit BIGAUTO column type.
    field_type = 'BIGAUTO'
4620
4621
class IdentityField(AutoField):
    """Integer primary key declared as an SQL identity column.

    The class default renders ``GENERATED BY DEFAULT AS IDENTITY``; passing
    ``generate_always=True`` switches this instance to
    ``GENERATED ALWAYS AS IDENTITY``.
    """
    field_type = 'INT GENERATED BY DEFAULT AS IDENTITY'

    def __init__(self, generate_always=False, **kwargs):
        always = bool(generate_always)
        if always:
            # Shadow the class-level type on this instance only.
            self.field_type = 'INT GENERATED ALWAYS AS IDENTITY'
        super(IdentityField, self).__init__(**kwargs)
4629
4630
class PrimaryKeyField(AutoField):
    """Deprecated alias of AutoField; reports deprecation when constructed."""
    def __init__(self, *args, **kwargs):
        notice = ('"PrimaryKeyField" has been renamed to "AutoField". '
                  'Please update your code accordingly as this will be '
                  'completely removed in a subsequent release.')
        __deprecated__(notice)
        super(PrimaryKeyField, self).__init__(*args, **kwargs)
4637
4638
class FloatField(Field):
    """Floating-point column; values are coerced with ``float()`` if possible."""
    field_type = 'FLOAT'

    def adapt(self, value):
        # Values that float() rejects with ValueError are passed through.
        try:
            coerced = float(value)
        except ValueError:
            coerced = value
        return coerced
4647
4648
class DoubleField(FloatField):
    # Same conversion as FloatField; only the column type differs.
    field_type = 'DOUBLE'
4651
4652
class DecimalField(Field):
    """Fixed-precision column rendered as ``DECIMAL(max_digits, decimal_places)``.

    With ``auto_round`` enabled, values written to the database are
    quantized to ``decimal_places`` using ``rounding`` (default:
    ``decimal.DefaultContext.rounding``). Values read back are returned as
    ``decimal.Decimal``.
    """
    field_type = 'DECIMAL'

    def __init__(self, max_digits=10, decimal_places=5, auto_round=False,
                 rounding=None, *args, **kwargs):
        self.max_digits = max_digits
        self.decimal_places = decimal_places
        self.auto_round = auto_round
        self.rounding = rounding or decimal.DefaultContext.rounding
        # Quantization step for auto_round: 10 ** -decimal_places.
        self._exp = decimal.Decimal(10) ** (-self.decimal_places)
        super(DecimalField, self).__init__(*args, **kwargs)

    def get_modifiers(self):
        return [self.max_digits, self.decimal_places]

    def db_value(self, value):
        if value is None:
            return None
        if not value:
            # Any other falsy value (0, '', Decimal('0.00'), ...) is stored
            # as Decimal zero.
            return decimal.Decimal(0)
        if not self.auto_round:
            return value
        as_decimal = decimal.Decimal(text_type(value))
        return as_decimal.quantize(self._exp, rounding=self.rounding)

    def python_value(self, value):
        if value is None or isinstance(value, decimal.Decimal):
            return value
        # Go through text so floats convert via their repr, not binary value.
        return decimal.Decimal(text_type(value))
4682
4683
class _StringField(Field):
    """Shared base for text columns.

    Values are coerced to text (bytes are decoded as UTF-8), and ``+`` on
    the field builds a string-concatenation expression.
    """
    def adapt(self, value):
        if isinstance(value, text_type):
            return value
        if isinstance(value, bytes_type):
            return value.decode('utf-8')
        return text_type(value)

    def __add__(self, other):
        return StringExpression(self, OP.CONCAT, other)

    def __radd__(self, other):
        return StringExpression(other, OP.CONCAT, self)
4694
4695
class CharField(_StringField):
    """Variable-length text column, ``VARCHAR(max_length)``.

    A falsy ``max_length`` (e.g. None) omits the length modifier.
    """
    field_type = 'VARCHAR'

    def __init__(self, max_length=255, *args, **kwargs):
        self.max_length = max_length
        super(CharField, self).__init__(*args, **kwargs)

    def get_modifiers(self):
        return [self.max_length] if self.max_length else None
4705
4706
class FixedCharField(CharField):
    """Fixed-width ``CHAR`` column.

    Non-empty values read from the database are stripped of surrounding
    whitespace (presumably the padding added by fixed-width storage).
    """
    field_type = 'CHAR'

    def python_value(self, value):
        converted = super(FixedCharField, self).python_value(value)
        return converted.strip() if converted else converted
4715
4716
class TextField(_StringField):
    # Unbounded text column; conversion is inherited from _StringField.
    field_type = 'TEXT'
4719
4720
class BlobField(Field):
    """Binary column.

    ``db_value`` wraps bytes in ``self._constructor``: ``bytearray`` by
    default, or whatever ``database.get_binary_type()`` returns once a
    database is known.
    """
    field_type = 'BLOB'

    def _db_hook(self, database):
        # Choose the binary wrapper for ``database`` (bytearray if None).
        self._constructor = (bytearray if database is None
                             else database.get_binary_type())

    def bind(self, model, name, set_attribute=True):
        self._constructor = bytearray
        database = model._meta.database
        if database:
            if isinstance(database, Proxy):
                # Resolve later via the proxy's callback mechanism.
                database.attach_callback(self._db_hook)
            else:
                self._db_hook(database)

        # Register with the model metadata too, so the wrapper type is
        # refreshed if the model's database is changed or set at run-time.
        model._meta._db_hooks.append(self._db_hook)
        return super(BlobField, self).bind(model, name, set_attribute)

    def db_value(self, value):
        if isinstance(value, text_type):
            # Text is encoded with 'raw_unicode_escape' before wrapping.
            value = value.encode('raw_unicode_escape')
        return self._constructor(value) if isinstance(value, bytes_type) else value
4750
4751
class BitField(BitwiseMixin, BigIntegerField):
    """Integer bitmask column (default 0) with named single-bit flags.

    Each call to ``flag()`` returns a descriptor for one bit. Assign it as a
    model class attribute to read/write that bit as a boolean, or use it in
    queries.
    """
    def __init__(self, *args, **kwargs):
        kwargs.setdefault('default', 0)
        super(BitField, self).__init__(*args, **kwargs)
        # Next bit value handed out by flag(); name-mangled to
        # _BitField__current_flag.
        self.__current_flag = 1

    def flag(self, value=None):
        """Return a descriptor for one bit of this field.

        With ``value=None`` the next unused power of two is allocated.
        An explicit ``value`` is used as-is, and subsequent automatic
        allocations continue from ``value << 1``.
        """
        if value is None:
            value = self.__current_flag
            self.__current_flag <<= 1
        else:
            self.__current_flag = value << 1

        class FlagDescriptor(ColumnBase):
            # Bound to one bit mask ``_value`` of ``_field``.
            def __init__(self, field, value):
                self._field = field
                self._value = value
                super(FlagDescriptor, self).__init__()
            def clear(self):
                # Expression clearing this bit: field & ~mask.
                return self._field.bin_and(~self._value)
            def set(self):
                # Expression setting this bit: field | mask.
                return self._field.bin_or(self._value)
            def __get__(self, instance, instance_type=None):
                # Class access returns the descriptor (usable in queries);
                # instance access reports whether the bit is set.
                if instance is None:
                    return self
                value = getattr(instance, self._field.name) or 0
                return (value & self._value) != 0
            def __set__(self, instance, is_set):
                # Only True/False (or values equal to them, e.g. 1/0) accepted.
                if is_set not in (True, False):
                    raise ValueError('Value must be either True or False')
                value = getattr(instance, self._field.name) or 0
                if is_set:
                    value |= self._value
                else:
                    value &= ~self._value
                setattr(instance, self._field.name, value)
            def __sql__(self, ctx):
                # In SQL the flag renders as "(field & mask) != 0".
                return ctx.sql(self._field.bin_and(self._value) != 0)
        return FlagDescriptor(self, value)
4791
4792
class BigBitFieldData(object):
    """Mutable bitmap view over a BigBitField value on a model instance.

    The underlying ``bytearray`` is the same object stored in
    ``instance.__data__[name]``, so setting, clearing or toggling bits
    updates the instance's value in place. Bit ``idx`` lives in byte
    ``idx // 8`` at bit position ``idx % 8`` (least-significant bit first).
    """
    def __init__(self, instance, name):
        self.instance = instance
        self.name = name
        value = self.instance.__data__.get(self.name)
        if not value:
            value = bytearray()
        elif not isinstance(value, bytearray):
            value = bytearray(value)
        # Store the bytearray back so the view and the instance share it.
        self._buffer = self.instance.__data__[self.name] = value

    def _ensure_length(self, idx):
        """Grow the buffer with zero bytes so bit ``idx`` is addressable.

        Returns the ``(byte_num, byte_offset)`` location of the bit.
        """
        byte_num, byte_offset = divmod(idx, 8)
        cur_size = len(self._buffer)
        if cur_size <= byte_num:
            self._buffer.extend(b'\x00' * ((byte_num + 1) - cur_size))
        return byte_num, byte_offset

    def set_bit(self, idx):
        """Set bit ``idx`` to 1, growing the buffer if needed."""
        byte_num, byte_offset = self._ensure_length(idx)
        self._buffer[byte_num] |= (1 << byte_offset)

    def clear_bit(self, idx):
        """Set bit ``idx`` to 0, growing the buffer if needed."""
        byte_num, byte_offset = self._ensure_length(idx)
        self._buffer[byte_num] &= ~(1 << byte_offset)

    def toggle_bit(self, idx):
        """Flip bit ``idx`` and return its new state."""
        byte_num, byte_offset = self._ensure_length(idx)
        self._buffer[byte_num] ^= (1 << byte_offset)
        return bool(self._buffer[byte_num] & (1 << byte_offset))

    def is_set(self, idx):
        """Return True if bit ``idx`` is set.

        This is read-only: bits beyond the end of the buffer are reported
        as unset without growing it, so querying a high bit does not change
        the value stored on the instance.
        """
        byte_num, byte_offset = divmod(idx, 8)
        if byte_num >= len(self._buffer):
            return False
        return bool(self._buffer[byte_num] & (1 << byte_offset))

    def __repr__(self):
        return repr(self._buffer)
4830
4831
class BigBitFieldAccessor(FieldAccessor):
    """Descriptor for BigBitField attributes.

    Instance access returns a fresh BigBitFieldData view over the stored
    bytes. Assignment normalizes memoryview/buffer/bytearray/BigBitFieldData/
    text (UTF-8 encoded) values to bytes before storing them, and rejects
    anything else with ValueError.
    """
    def __get__(self, instance, instance_type=None):
        if instance is not None:
            return BigBitFieldData(instance, self.name)
        return self.field

    def __set__(self, instance, value):
        if isinstance(value, memoryview):
            raw = value.tobytes()
        elif isinstance(value, buffer_type):
            raw = bytes(value)
        elif isinstance(value, bytearray):
            raw = bytes_type(value)
        elif isinstance(value, BigBitFieldData):
            raw = bytes_type(value._buffer)
        elif isinstance(value, text_type):
            raw = value.encode('utf-8')
        elif isinstance(value, bytes_type):
            raw = value
        else:
            raise ValueError('Value must be either a bytes, memoryview or '
                             'BigBitFieldData instance.')
        super(BigBitFieldAccessor, self).__set__(instance, raw)
4852
4853
class BigBitField(BlobField):
    """Arbitrarily large bitmap stored as a blob.

    Model attribute access goes through BigBitFieldAccessor. The default is
    the ``bytes_type`` callable (i.e. empty bytes), and values are written
    to the database as ``bytes_type``.
    """
    accessor_class = BigBitFieldAccessor

    def __init__(self, *args, **kwargs):
        if 'default' not in kwargs:
            kwargs['default'] = bytes_type
        super(BigBitField, self).__init__(*args, **kwargs)

    def db_value(self, value):
        if value is None:
            return value
        return bytes_type(value)
4863
4864
class UUIDField(Field):
    """UUID column whose database value is the 32-character hex form.

    ``db_value`` accepts ``uuid.UUID`` instances, 16-byte raw values, and
    any string ``uuid.UUID()`` can parse. Values it cannot interpret are
    passed through unchanged. ``python_value`` returns ``uuid.UUID`` (or
    None).
    """
    field_type = 'UUID'

    def db_value(self, value):
        if isinstance(value, basestring) and len(value) == 32:
            # Hex string. No transformation is necessary.
            return value
        elif isinstance(value, bytes) and len(value) == 16:
            # Allow raw binary representation.
            value = uuid.UUID(bytes=value)
        if isinstance(value, uuid.UUID):
            return value.hex
        try:
            return uuid.UUID(value).hex
        except (AttributeError, TypeError, ValueError):
            # uuid.UUID() could not interpret the value (None, a non-string,
            # or a malformed string): pass it through unchanged. Catching
            # only these avoids swallowing KeyboardInterrupt/SystemExit.
            return value

    def python_value(self, value):
        if isinstance(value, uuid.UUID):
            return value
        return uuid.UUID(value) if value is not None else None
4886
4887
class BinaryUUIDField(BlobField):
    """UUID column stored as 16 raw bytes.

    ``db_value`` accepts ``uuid.UUID`` instances, 16-byte raw values, and
    UUID strings in any form ``uuid.UUID()`` parses (32-char hex,
    hyphenated, braced, ``urn:uuid:``), matching what UUIDField accepts.
    Malformed strings and unsupported types raise ValueError. ``python_value``
    returns ``uuid.UUID`` (or None).
    """
    field_type = 'UUIDB'

    def db_value(self, value):
        if isinstance(value, bytes) and len(value) == 16:
            # Raw binary value. No transformation is necessary.
            return self._constructor(value)
        elif isinstance(value, basestring):
            # Any textual UUID form; uuid.UUID() raises ValueError when the
            # string is malformed.
            value = uuid.UUID(hex=value)
        if isinstance(value, uuid.UUID):
            return self._constructor(value.bytes)
        elif value is not None:
            raise ValueError('value for binary UUID field must be UUID(), '
                             'a hexadecimal string, or a bytes object.')

    def python_value(self, value):
        if isinstance(value, uuid.UUID):
            return value
        elif isinstance(value, memoryview):
            value = value.tobytes()
        elif value and not isinstance(value, bytes):
            # Driver-specific binary wrappers (e.g. bytearray) -> bytes.
            value = bytes(value)
        return uuid.UUID(bytes=value) if value is not None else None
4912
4913
def _date_part(date_part):
    """Build a property getter extracting ``date_part`` from a date field.

    The getter delegates to the bound model's database:
    ``database.extract_date(date_part, field)``.
    """
    def extractor(self):
        database = self.model._meta.database
        return database.extract_date(date_part, self)
    return extractor
4918
def format_date_time(value, formats, post_process=None):
    """Parse ``value`` with the first ``strptime`` format that works.

    Each entry of ``formats`` is tried in order, and the parsed datetime is
    passed through ``post_process`` (identity when falsy). A ValueError
    raised while parsing or post-processing moves on to the next format.
    If nothing succeeds, ``value`` is returned unchanged.
    """
    convert = post_process or (lambda parsed: parsed)
    for candidate in formats:
        try:
            result = convert(datetime.datetime.strptime(value, candidate))
        except ValueError:
            continue
        else:
            return result
    return value
4927
def simple_date_time(value):
    """Parse ``value`` as ``YYYY-MM-DD HH:MM:SS``; return it unchanged on failure."""
    parsed = value
    try:
        parsed = datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
    except (TypeError, ValueError):
        pass
    return parsed
4933
4934
class _BaseFormattedField(Field):
    """Base for fields parsed from strings with a list of strptime ``formats``.

    Subclasses supply a class-level ``formats`` list; a non-None
    ``formats`` argument overrides it for this instance only.
    """
    formats = None

    def __init__(self, formats=None, *args, **kwargs):
        has_override = formats is not None
        if has_override:
            self.formats = formats
        super(_BaseFormattedField, self).__init__(*args, **kwargs)
4942
4943
class DateTimeField(_BaseFormattedField):
    """Date-and-time column.

    Non-empty strings are parsed with ``formats`` (tried in order); other
    values pass through unchanged. Date-part properties (``year`` ...
    ``second``) build database-specific extraction expressions.
    """
    field_type = 'DATETIME'
    formats = [
        '%Y-%m-%d %H:%M:%S.%f',
        '%Y-%m-%d %H:%M:%S',
        '%Y-%m-%d',
    ]

    def adapt(self, value):
        if not value or not isinstance(value, basestring):
            return value
        return format_date_time(value, self.formats)

    def to_timestamp(self):
        database = self.model._meta.database
        return database.to_timestamp(self)

    def truncate(self, part):
        database = self.model._meta.database
        return database.truncate_date(part, self)

    year = property(_date_part('year'))
    month = property(_date_part('month'))
    day = property(_date_part('day'))
    hour = property(_date_part('hour'))
    minute = property(_date_part('minute'))
    second = property(_date_part('second'))
4969
4970
class DateField(_BaseFormattedField):
    """Date-only column.

    Non-empty strings are parsed with ``formats`` and reduced to a date;
    datetimes are reduced to their date; other values pass through.
    """
    field_type = 'DATE'
    formats = [
        '%Y-%m-%d',
        '%Y-%m-%d %H:%M:%S',
        '%Y-%m-%d %H:%M:%S.%f',
    ]

    def adapt(self, value):
        if value:
            if isinstance(value, basestring):
                return format_date_time(value, self.formats,
                                        lambda parsed: parsed.date())
            if isinstance(value, datetime.datetime):
                return value.date()
        return value

    def to_timestamp(self):
        database = self.model._meta.database
        return database.to_timestamp(self)

    def truncate(self, part):
        database = self.model._meta.database
        return database.truncate_date(part, self)

    year = property(_date_part('year'))
    month = property(_date_part('month'))
    day = property(_date_part('day'))
4996
4997
class TimeField(_BaseFormattedField):
    """Time-of-day column.

    Non-empty strings are parsed with ``formats`` and reduced to a time;
    datetimes are reduced to their time; timedeltas are interpreted as an
    offset from midnight. Other values pass through unchanged.
    """
    field_type = 'TIME'
    formats = [
        '%H:%M:%S.%f',
        '%H:%M:%S',
        '%H:%M',
        '%Y-%m-%d %H:%M:%S.%f',
        '%Y-%m-%d %H:%M:%S',
    ]

    def adapt(self, value):
        if value and isinstance(value, basestring):
            return format_date_time(value, self.formats,
                                    lambda parsed: parsed.time())
        if value and isinstance(value, datetime.datetime):
            return value.time()
        if isinstance(value, datetime.timedelta):
            # Offset from midnight (some drivers return TIME as timedelta).
            return (datetime.datetime.min + value).time()
        return value

    hour = property(_date_part('hour'))
    minute = property(_date_part('minute'))
    second = property(_date_part('second'))
5022
5023
def _timestamp_date_part(date_part):
    """Build a property getter extracting ``date_part`` from a timestamp field.

    The stored integer is divided by the field's ``resolution`` (when > 1)
    to get seconds, converted with ``database.from_timestamp()``, and then
    passed to ``database.extract_date(date_part, ...)``.
    """
    def extractor(self):
        database = self.model._meta.database
        if self.resolution > 1:
            seconds = self / Value(self.resolution, converter=False)
        else:
            seconds = self
        return database.extract_date(date_part, database.from_timestamp(seconds))
    return extractor
5031
5032
class TimestampField(BigIntegerField):
    """Datetime stored as an integer count of ticks since the epoch.

    ``resolution`` is ticks per second. It may be given as a power of ten
    from 1 to 10**6, or as an exponent 2..6 (converted to 10**exponent).
    A falsy value means 1 (whole seconds). With ``utc=True`` naive datetimes
    are treated as UTC; otherwise they are treated as local time.
    """
    # Support second -> microsecond resolution.
    valid_resolutions = [10**i for i in range(7)]

    def __init__(self, *args, **kwargs):
        self.resolution = kwargs.pop('resolution', None)

        if not self.resolution:
            self.resolution = 1
        elif self.resolution in range(2, 7):
            # Small integers are taken as an exponent, e.g. 3 -> 1000 (ms).
            self.resolution = 10 ** self.resolution
        elif self.resolution not in self.valid_resolutions:
            raise ValueError('TimestampField resolution must be one of: %s' %
                             ', '.join(str(i) for i in self.valid_resolutions))
        # Microseconds represented by one tick.
        self.ticks_to_microsecond = 1000000 // self.resolution

        self.utc = kwargs.pop('utc', False) or False
        # Default to "now", in UTC or local time to match the utc setting.
        dflt = datetime.datetime.utcnow if self.utc else datetime.datetime.now
        kwargs.setdefault('default', dflt)
        super(TimestampField, self).__init__(*args, **kwargs)

    def local_to_utc(self, dt):
        # Convert naive local datetime into naive UTC, e.g.:
        # 2019-03-01T12:00:00 (local=US/Central) -> 2019-03-01T18:00:00.
        # 2019-05-01T12:00:00 (local=US/Central) -> 2019-05-01T17:00:00.
        # 2019-03-01T12:00:00 (local=UTC)        -> 2019-03-01T12:00:00.
        return datetime.datetime(*time.gmtime(time.mktime(dt.timetuple()))[:6])

    def utc_to_local(self, dt):
        # Convert a naive UTC datetime into local time, e.g.:
        # 2019-03-01T18:00:00 (local=US/Central) -> 2019-03-01T12:00:00.
        # 2019-05-01T17:00:00 (local=US/Central) -> 2019-05-01T12:00:00.
        # 2019-03-01T12:00:00 (local=UTC)        -> 2019-03-01T12:00:00.
        ts = calendar.timegm(dt.utctimetuple())
        return datetime.datetime.fromtimestamp(ts)

    def get_timestamp(self, value):
        """Return whole seconds since the epoch for datetime ``value``.

        Sub-second precision is not included here (timetuple() drops it);
        db_value() adds microseconds separately.
        """
        if self.utc:
            # If utc-mode is on, then we assume all naive datetimes are in UTC.
            return calendar.timegm(value.utctimetuple())
        else:
            return time.mktime(value.timetuple())

    def db_value(self, value):
        """Convert a datetime/date (or a number of seconds) to integer ticks.

        Dates are treated as midnight. Any other non-None value is assumed
        to be a number of seconds (NOTE(review): not type-checked) and is
        scaled by ``resolution``.
        """
        if value is None:
            return

        if isinstance(value, datetime.datetime):
            pass
        elif isinstance(value, datetime.date):
            value = datetime.datetime(value.year, value.month, value.day)
        else:
            return int(round(value * self.resolution))

        timestamp = self.get_timestamp(value)
        if self.resolution > 1:
            # Add the fractional second, then scale seconds -> ticks.
            timestamp += (value.microsecond * .000001)
            timestamp *= self.resolution
        return int(round(timestamp))

    def python_value(self, value):
        """Convert integer/float ticks to a naive datetime.

        Uses UTC when ``utc`` is enabled, local time otherwise. Values that
        are not int/float/long (including None) are returned unchanged.
        """
        if value is not None and isinstance(value, (int, float, long)):
            if self.resolution > 1:
                # Split into whole seconds and leftover ticks.
                value, ticks = divmod(value, self.resolution)
                microseconds = int(ticks * self.ticks_to_microsecond)
            else:
                microseconds = 0

            if self.utc:
                value = datetime.datetime.utcfromtimestamp(value)
            else:
                value = datetime.datetime.fromtimestamp(value)

            if microseconds:
                value = value.replace(microsecond=microseconds)

        return value

    def from_timestamp(self):
        # Database expression converting the stored ticks back to a datetime.
        expr = ((self / Value(self.resolution, converter=False))
                if self.resolution > 1 else self)
        return self.model._meta.database.from_timestamp(expr)

    year = property(_timestamp_date_part('year'))
    month = property(_timestamp_date_part('month'))
    day = property(_timestamp_date_part('day'))
    hour = property(_timestamp_date_part('hour'))
    minute = property(_timestamp_date_part('minute'))
    second = property(_timestamp_date_part('second'))
5122
5123
class IPField(BigIntegerField):
    """IPv4 address stored as an unsigned 32-bit integer (network byte order)."""
    def db_value(self, val):
        if val is None:
            return None
        packed = socket.inet_aton(val)
        return struct.unpack('!I', packed)[0]

    def python_value(self, val):
        if val is None:
            return None
        return socket.inet_ntoa(struct.pack('!I', val))
5132
5133
class BooleanField(Field):
    field_type = 'BOOL'
    # ``bool`` is a type, not a function, so it is not bound as a method:
    # ``self.adapt(v)`` is simply ``bool(v)`` (None is skipped by Field).
    adapt = bool
5137
5138
class BareField(Field):
    """Field with no declared column type (ddl_datatype() yields None).

    An optional ``adapt`` callable replaces the inherited ``adapt()`` for
    this instance only.
    """
    def __init__(self, adapt=None, *args, **kwargs):
        super(BareField, self).__init__(*args, **kwargs)
        if adapt is None:
            return
        self.adapt = adapt

    def ddl_datatype(self, ctx):
        return None
5147
5148
class ForeignKeyField(Field):
    """Column referencing another model (by default its primary key).

    Attribute access goes through ForeignKeyAccessor. When bound, it also
    installs an ObjectIdAccessor (raw id, e.g. ``parent_id``) on the model
    and, unless suppressed, a BackrefAccessor on the related model.
    Passing ``model='self'`` declares a self-reference resolved at bind().
    """
    accessor_class = ForeignKeyAccessor

    def __init__(self, model, field=None, backref=None, on_delete=None,
                 on_update=None, deferrable=None, _deferred=None,
                 rel_model=None, to_field=None, object_id_name=None,
                 lazy_load=True, constraint_name=None, related_name=None,
                 *args, **kwargs):
        # Foreign keys are indexed unless the caller says otherwise.
        kwargs.setdefault('index', True)

        super(ForeignKeyField, self).__init__(*args, **kwargs)

        # Deprecated aliases (rel_model -> model, to_field -> field,
        # related_name -> backref) are reported via __deprecated__() and
        # override the new-style argument when given.
        if rel_model is not None:
            __deprecated__('"rel_model" has been deprecated in favor of '
                           '"model" for ForeignKeyField objects.')
            model = rel_model
        if to_field is not None:
            __deprecated__('"to_field" has been deprecated in favor of '
                           '"field" for ForeignKeyField objects.')
            field = to_field
        if related_name is not None:
            __deprecated__('"related_name" has been deprecated in favor of '
                           '"backref" for Field objects.')
            backref = related_name

        self._is_self_reference = model == 'self'
        self.rel_model = model
        # Referenced field: None (primary key), a field name, or a field;
        # resolved in bind().
        self.rel_field = field
        self.declared_backref = backref
        self.backref = None
        self.on_delete = on_delete
        self.on_update = on_update
        self.deferrable = deferrable
        self.deferred = _deferred
        self.object_id_name = object_id_name
        self.lazy_load = lazy_load
        self.constraint_name = constraint_name

    @property
    def field_type(self):
        # Mirror the referenced field's type. An auto-increment key is
        # referenced with a plain integer type (INT, or BIGINT for
        # BigAutoField) instead of AUTO/BIGAUTO.
        if not isinstance(self.rel_field, AutoField):
            return self.rel_field.field_type
        elif isinstance(self.rel_field, BigAutoField):
            return BigIntegerField.field_type
        return IntegerField.field_type

    def get_modifiers(self):
        # Borrow the referenced field's modifiers (e.g. a VARCHAR length),
        # except for auto-increment keys.
        if not isinstance(self.rel_field, AutoField):
            return self.rel_field.get_modifiers()
        return super(ForeignKeyField, self).get_modifiers()

    def adapt(self, value):
        return self.rel_field.adapt(value)

    def db_value(self, value):
        # A related model instance is reduced to its referenced field value.
        if isinstance(value, self.rel_model):
            value = getattr(value, self.rel_field.name)
        return self.rel_field.db_value(value)

    def python_value(self, value):
        if isinstance(value, self.rel_model):
            return value
        return self.rel_field.python_value(value)

    def bind(self, model, name, set_attribute=True):
        """Attach to ``model`` as ``name`` and resolve the relationship.

        - ``column_name`` defaults to ``name`` with "_id" appended (unless
          ``name`` already ends in "_id").
        - ``object_id_name`` defaults to ``column_name``, with "_id" appended
          if that would collide with ``name``. An explicit
          ``object_id_name`` equal to ``name`` raises ValueError.
        - A "self" reference becomes ``model``; a string ``rel_field`` is
          looked up on the related model; None means its primary key.
        - The backref defaults to "<model name>_set". A callable
          ``declared_backref`` is called with this field to compute it.
          A backref of "!" or "+" skips installing the BackrefAccessor.
        """
        if not self.column_name:
            self.column_name = name if name.endswith('_id') else name + '_id'
        if not self.object_id_name:
            self.object_id_name = self.column_name
            if self.object_id_name == name:
                self.object_id_name += '_id'
        elif self.object_id_name == name:
            raise ValueError('ForeignKeyField "%s"."%s" specifies an '
                             'object_id_name that conflicts with its field '
                             'name.' % (model._meta.name, name))
        if self._is_self_reference:
            self.rel_model = model
        if isinstance(self.rel_field, basestring):
            self.rel_field = getattr(self.rel_model, self.rel_field)
        elif self.rel_field is None:
            self.rel_field = self.rel_model._meta.primary_key

        # Bind field before assigning backref, so field is bound when
        # calling declared_backref() (if callable).
        super(ForeignKeyField, self).bind(model, name, set_attribute)
        self.safe_name = self.object_id_name

        if callable_(self.declared_backref):
            self.backref = self.declared_backref(self)
        else:
            # A non-callable declared value simply becomes the backref.
            self.backref, self.declared_backref = self.declared_backref, None
        if not self.backref:
            self.backref = '%s_set' % model._meta.name

        if set_attribute:
            setattr(model, self.object_id_name, ObjectIdAccessor(self))
            # Substring test against '!+': '!' or '+' disable the backref.
            if self.backref not in '!+':
                setattr(self.rel_model, self.backref, BackrefAccessor(self))

    def foreign_key_constraint(self):
        """Build the table-level FOREIGN KEY ... REFERENCES ... clause.

        Optional CONSTRAINT name, ON DELETE, ON UPDATE and DEFERRABLE parts
        are included when configured.
        """
        parts = []
        if self.constraint_name:
            parts.extend((SQL('CONSTRAINT'), Entity(self.constraint_name)))
        parts.extend([
            SQL('FOREIGN KEY'),
            EnclosedNodeList((self,)),
            SQL('REFERENCES'),
            self.rel_model,
            EnclosedNodeList((self.rel_field,))])
        if self.on_delete:
            parts.append(SQL('ON DELETE %s' % self.on_delete))
        if self.on_update:
            parts.append(SQL('ON UPDATE %s' % self.on_update))
        if self.deferrable:
            parts.append(SQL('DEFERRABLE %s' % self.deferrable))
        return NodeList(parts)

    def __getattr__(self, attr):
        # Called only for attributes not found normally: lets
        # ``fk_field.some_field`` resolve to the related model's field.
        if attr.startswith('__'):
            # Prevent recursion error when deep-copying.
            raise AttributeError('Cannot look-up non-existant "__" methods.')
        if attr in self.rel_model._meta.fields:
            return self.rel_model._meta.fields[attr]
        raise AttributeError('Foreign-key has no attribute %s, nor is it a '
                             'valid field on the related model.' % attr)
5274
5275
class DeferredForeignKey(Field):
    """Stand-in for a ForeignKeyField whose target model is not defined yet.

    The target is remembered by lower-cased class name. When a model of that
    name is later passed to :meth:`resolve`, a real ForeignKeyField is built
    from the stored keyword arguments and installed on the owning model in
    place of this placeholder.
    """
    # Placeholders still waiting for their target model; shared class-wide.
    _unresolved = set()

    def __init__(self, rel_model_name, **kwargs):
        # Keep the kwargs so the real ForeignKeyField (and any copy of this
        # placeholder) can be built with the same options.
        self.field_kwargs = kwargs
        self.rel_model_name = rel_model_name.lower()
        DeferredForeignKey._unresolved.add(self)
        super(DeferredForeignKey, self).__init__(
            column_name=kwargs.get('column_name'),
            null=kwargs.get('null'))

    # Hash by identity so instances can be stored in ``_unresolved``.
    __hash__ = object.__hash__

    def __deepcopy__(self, memo=None):
        # A copy is a brand-new placeholder, which registers itself as
        # unresolved in __init__.
        return DeferredForeignKey(self.rel_model_name, **self.field_kwargs)

    def set_model(self, rel_model):
        """Replace this placeholder on its model with a real ForeignKeyField."""
        real_fk = ForeignKeyField(rel_model, _deferred=True,
                                  **self.field_kwargs)
        self.model._meta.add_field(self.name, real_fk)

    @staticmethod
    def resolve(model_cls):
        """Resolve every pending placeholder that targets ``model_cls``.

        Placeholders are processed in ascending ``_order``; each resolved one
        is removed from the pending set.
        """
        wanted = model_cls.__name__.lower()
        pending = sorted(DeferredForeignKey._unresolved,
                         key=lambda placeholder: placeholder._order)
        for placeholder in pending:
            if placeholder.rel_model_name != wanted:
                continue
            placeholder.set_model(model_cls)
            DeferredForeignKey._unresolved.discard(placeholder)
5304
5305
class DeferredThroughModel(object):
    """Placeholder for a many-to-many "through" model that is defined later.

    Fields referencing it are queued by :meth:`set_field`. Once the real
    model exists, :meth:`set_model` points every queued field at it and adds
    the field to its source model.
    """
    def __init__(self):
        # Pending (source model, many-to-many field, attribute name) triples.
        self._refs = []

    def set_field(self, model, field, name):
        """Queue ``field`` (to be named ``name`` on ``model``)."""
        self._refs.append((model, field, name))

    def set_model(self, through_model):
        """Point every queued field at ``through_model`` and register it."""
        for ref in self._refs:
            src_model, m2m_field, attr_name = ref
            # Assign the real through model *before* add_field(): binding a
            # field that still sees the placeholder would just queue it again.
            m2m_field.through_model = through_model
            src_model._meta.add_field(attr_name, m2m_field)
5317
5318
class MetaField(Field):
    """Base class for model attributes that are not plain table columns.

    ``Metadata.add_field`` only binds instances of this class; it does not
    enter them into ``fields``/``columns``. These class-level defaults stand
    in for the column attributes an ordinary ``Field`` would carry.
    """
    column_name = None
    default = None
    model = None
    name = None
    primary_key = False
5322
5323
class ManyToManyFieldAccessor(FieldAccessor):
    """Descriptor exposing a many-to-many relation on model instances.

    Accessed on the class it returns the ManyToManyField; on an instance it
    returns the related objects, either from a list already stored in the
    source FK's backref attribute or as a query through the intermediate
    model. Assignment replaces the instance's related set.
    """
    def __init__(self, model, field, name):
        super(ManyToManyFieldAccessor, self).__init__(model, field, name)
        self.model = field.model
        self.rel_model = field.rel_model
        self.through_model = field.through_model
        # The through model must hold a foreign key to each side; the first
        # one registered for each model is used.
        refs_by_model = self.through_model._meta.model_refs
        src_fks = refs_by_model[self.model]
        dest_fks = refs_by_model[self.rel_model]
        if not src_fks:
            raise ValueError('Cannot find foreign-key to "%s" on "%s" model.' %
                             (self.model, self.through_model))
        if not dest_fks:
            raise ValueError('Cannot find foreign-key to "%s" on "%s" model.' %
                             (self.rel_model, self.through_model))
        self.src_fk, self.dest_fk = src_fks[0], dest_fks[0]

    def __get__(self, instance, instance_type=None, force_query=False):
        if instance is None:
            return self.field

        if not force_query and self.src_fk.backref != '+':
            # If the backref attribute already holds a list (presumably
            # populated by prefetching -- assumption), map it to the related
            # objects instead of issuing a query.
            cached = getattr(instance, self.src_fk.backref)
            if isinstance(cached, list):
                return [getattr(row, self.dest_fk.name) for row in cached]

        src_id = getattr(instance, self.src_fk.rel_field.name)
        query = ManyToManyQuery(instance, self, self.rel_model)
        return (query
                .join(self.through_model)
                .join(self.model)
                .where(self.src_fk == src_id))

    def __set__(self, instance, value):
        # Always go to the database, then replace the existing links.
        related = self.__get__(instance, force_query=True)
        related.add(value, clear_existing=True)
5359
5360
class ManyToManyField(MetaField):
    """Many-to-many relation between the owning model and ``model``.

    Rows are linked through an intermediate ("through") model holding one
    foreign key to each side. If none is supplied, one is generated on first
    access of :attr:`through_model`. Binding a forward field also installs a
    mirror ``ManyToManyField`` (flagged ``_is_backref``) on the related model.
    """
    accessor_class = ManyToManyFieldAccessor

    def __init__(self, model, backref=None, through_model=None, on_delete=None,
                 on_update=None, _is_backref=False):
        """
        Args:
            model: the related model class.
            backref: attribute name of the mirror field on ``model``; when
                falsy, bind() uses the owning model's ``_meta.name`` + ``'s'``.
            through_model: a Model or DeferredThroughModel to use as the
                intermediate table, or None to generate one lazily.
            on_delete, on_update: actions for the generated through model's
                foreign keys.
            _is_backref: internal flag marking the auto-created mirror field.

        Raises:
            TypeError: ``through_model`` is neither a Model nor a
                DeferredThroughModel.
            ValueError: ``on_delete``/``on_update`` given together with an
                explicit ``through_model`` (except for the mirror field).
        """
        # Field.__init__ is not called; column-related defaults come from
        # MetaField's class attributes.
        if through_model is not None:
            if not (isinstance(through_model, DeferredThroughModel) or
                    is_model(through_model)):
                raise TypeError('Unexpected value for through_model. Expected '
                                'Model or DeferredThroughModel.')
            if not _is_backref and (on_delete is not None or on_update is not None):
                raise ValueError('Cannot specify on_delete or on_update when '
                                 'through_model is specified.')
        self.rel_model = model
        self.backref = backref
        self._through_model = through_model
        self._on_delete = on_delete
        self._on_update = on_update
        self._is_backref = _is_backref

    def _get_descriptor(self):
        # Accessors are constructed as (model, field, name) -- the same way
        # bind() installs them. Only meaningful once the field is bound
        # (model and name are None before bind()).
        return self.accessor_class(self.model, self, self.name)

    def bind(self, model, name, set_attribute=True):
        """Bind to ``model`` as ``name`` and install the mirror field.

        If the through model is still a DeferredThroughModel, the field is
        queued on it and binding is postponed until the model is resolved.
        """
        if isinstance(self._through_model, DeferredThroughModel):
            self._through_model.set_field(model, self, name)
            return

        super(ManyToManyField, self).bind(model, name, set_attribute)

        if not self._is_backref:
            # Mirror field on the related model, pointing back at ``name``
            # and sharing the same through model.
            many_to_many_field = ManyToManyField(
                self.model,
                backref=name,
                through_model=self.through_model,
                on_delete=self._on_delete,
                on_update=self._on_update,
                _is_backref=True)
            self.backref = self.backref or model._meta.name + 's'
            self.rel_model._meta.add_field(self.backref, many_to_many_field)

    def get_models(self):
        """Return ``[declaring model, target model]`` in a canonical order.

        The forward field and its mirror return the same list. The boolean
        sort keys always differ, so the models themselves are never compared.
        """
        return [model for _, model in sorted((
            (self._is_backref, self.model),
            (not self._is_backref, self.rel_model)))]

    @property
    def through_model(self):
        """The intermediate model, generated on first access if not given."""
        if self._through_model is None:
            self._through_model = self._create_through_model()
        return self._through_model

    @through_model.setter
    def through_model(self, value):
        self._through_model = value

    def _create_through_model(self):
        """Generate the default intermediate model.

        Table ``<lhs table>_<rhs table>_through`` with one ForeignKeyField per
        side (named after each model's ``_meta.name``), a unique index over
        both, and class name ``<Lhs><Rhs>Through``.
        """
        lhs, rhs = self.get_models()
        tables = [model._meta.table_name for model in (lhs, rhs)]

        class Meta:
            database = self.model._meta.database
            schema = self.model._meta.schema
            table_name = '%s_%s_through' % tuple(tables)
            indexes = (
                ((lhs._meta.name, rhs._meta.name),
                 True),)

        params = {'on_delete': self._on_delete, 'on_update': self._on_update}
        attrs = {
            lhs._meta.name: ForeignKeyField(lhs, **params),
            rhs._meta.name: ForeignKeyField(rhs, **params),
            'Meta': Meta}

        klass_name = '%s%sThrough' % (lhs.__name__, rhs.__name__)
        return type(klass_name, (Model,), attrs)

    def get_through_model(self):
        # XXX: Deprecated. Just use the "through_model" property.
        return self.through_model
5441
5442
class VirtualField(MetaField):
    """Model attribute without a backing column that can still convert values.

    When a field class is supplied (or set as the ``field_class`` class
    attribute), an instance of it performs ``db_value``/``python_value``
    conversions; otherwise values pass through unchanged.
    """
    field_class = None

    def __init__(self, field_class=None, *args, **kwargs):
        # Explicit argument wins over the class-level default.
        converter_cls = self.field_class if field_class is None else field_class
        if converter_cls is None:
            self.field_instance = None
        else:
            self.field_instance = converter_cls()
        super(VirtualField, self).__init__(*args, **kwargs)

    def db_value(self, value):
        """Convert ``value`` for the database via the helper field, if any."""
        converter = self.field_instance
        return value if converter is None else converter.db_value(value)

    def python_value(self, value):
        """Convert ``value`` for Python via the helper field, if any."""
        converter = self.field_instance
        return value if converter is None else converter.python_value(value)

    def bind(self, model, name, set_attribute=True):
        """Attach to ``model`` as ``name`` and install the accessor.

        ``set_attribute`` is accepted for signature compatibility but is not
        consulted: the accessor is always installed.
        """
        self.model = model
        self.column_name = self.name = self.safe_name = name
        accessor = self.accessor_class(model, self, name)
        setattr(model, name, accessor)
5465
5466
class CompositeKey(MetaField):
    """Primary key spanning several of the model's fields.

    Reading it on an instance yields a tuple of the component values;
    assigning a list/tuple sets each component. ``==`` builds an AND-ed SQL
    expression comparing each component field with the matching value.
    """
    sequence = None

    def __init__(self, *field_names):
        # Names of the component fields, in key order.
        self.field_names = field_names
        # Cached by safe_field_names once the key is bound to a model.
        self._safe_field_names = None

    @property
    def safe_field_names(self):
        """Attribute names used to read each component's value.

        Before binding (``model`` is None) this is just ``field_names``.
        Afterwards each field's ``safe_name`` is used and cached; for foreign
        keys that is the raw-id attribute rather than the related object.
        """
        if self._safe_field_names is None:
            if self.model is None:
                return self.field_names

            self._safe_field_names = [self.model._meta.fields[f].safe_name
                                      for f in self.field_names]
        return self._safe_field_names

    def __get__(self, instance, instance_type=None):
        if instance is not None:
            return tuple([getattr(instance, f) for f in self.safe_field_names])
        return self

    def __set__(self, instance, value):
        if not isinstance(value, (list, tuple)):
            raise TypeError('A list or tuple must be used to set the value of '
                            'a composite primary key.')
        if len(value) != len(self.field_names):
            raise ValueError('The length of the value must equal the number '
                             'of columns of the composite primary key.')
        for idx, field_value in enumerate(value):
            setattr(instance, self.field_names[idx], field_value)

    def __eq__(self, other):
        """Build ``f1 = v1 AND f2 = v2 ...`` over the component fields.

        Raises:
            ValueError: ``other`` is a list/tuple whose length differs from
                the number of key columns. ``zip`` would otherwise drop the
                extra fields/values and yield an under-constrained expression.
        """
        if (isinstance(other, (list, tuple)) and
                len(other) != len(self.field_names)):
            raise ValueError('The length of the value must equal the number '
                             'of columns of the composite primary key.')
        expressions = [(self.model._meta.fields[field] == value)
                       for field, value in zip(self.field_names, other)]
        return reduce(operator.and_, expressions)

    def __ne__(self, other):
        return ~(self == other)

    def __hash__(self):
        return hash((self.model.__name__, self.field_names))

    def __sql__(self, ctx):
        # If the composite PK is being selected, do not use parens. Elsewhere,
        # such as in an expression, we want to use parentheses and treat it as
        # a row value.
        parens = ctx.scope != SCOPE_SOURCE
        return ctx.sql(NodeList([self.model._meta.fields[field]
                                 for field in self.field_names], ', ', parens))

    def bind(self, model, name, set_attribute=True):
        # The key object itself is the descriptor installed on the model.
        self.model = model
        self.column_name = self.name = self.safe_name = name
        setattr(model, self.name, self)
5522
5523
5524class _SortedFieldList(object):
5525    __slots__ = ('_keys', '_items')
5526
5527    def __init__(self):
5528        self._keys = []
5529        self._items = []
5530
5531    def __getitem__(self, i):
5532        return self._items[i]
5533
5534    def __iter__(self):
5535        return iter(self._items)
5536
5537    def __contains__(self, item):
5538        k = item._sort_key
5539        i = bisect_left(self._keys, k)
5540        j = bisect_right(self._keys, k)
5541        return item in self._items[i:j]
5542
5543    def index(self, field):
5544        return self._keys.index(field._sort_key)
5545
5546    def insert(self, item):
5547        k = item._sort_key
5548        i = bisect_left(self._keys, k)
5549        self._keys.insert(i, k)
5550        self._items.insert(i, item)
5551
5552    def remove(self, item):
5553        idx = self.index(item)
5554        del self._items[idx]
5555        del self._keys[idx]
5556
5557
5558# MODELS
5559
5560
class SchemaManager(object):
    """Builds and runs DDL (tables, indexes, sequences, foreign keys) for a
    single model class.

    Each ``_xxx`` method builds and returns an SQL context without running
    it; the matching public method executes that context on
    ``self.database``.
    """
    def __init__(self, model, database=None, **context_options):
        """
        Args:
            model: the model class whose schema is managed.
            database: optional database overriding ``model._meta.database``.
            **context_options: forwarded to ``database.get_sql_context()``;
                ``scope`` defaults to ``SCOPE_VALUES``.
        """
        self.model = model
        self._database = database
        context_options.setdefault('scope', SCOPE_VALUES)
        self.context_options = context_options

    @property
    def database(self):
        """The explicitly set database, falling back to the model's.

        Raises ImproperlyConfigured when neither is set.
        """
        db = self._database or self.model._meta.database
        if db is None:
            raise ImproperlyConfigured('database attribute does not appear to '
                                       'be set on the model: %s' % self.model)
        return db

    @database.setter
    def database(self, value):
        self._database = value

    def _create_context(self):
        # Fresh SQL context configured with this manager's options.
        return self.database.get_sql_context(**self.context_options)

    def _create_table(self, safe=True, **options):
        """Build ``CREATE [TEMPORARY] TABLE [IF NOT EXISTS] <table> (...)``.

        The parenthesised list holds, in order: every column definition, a
        PRIMARY KEY constraint for composite keys, FOREIGN KEY constraints for
        non-deferred foreign keys, ``Meta`` constraints, then ``key=value``
        table options. ``table_settings`` and ``WITHOUT ROWID`` follow the
        closing parenthesis.

        Raises:
            ValueError: a ``table_settings`` entry is not a string.
        """
        # 'temporary' is consumed here so it is not rendered as an option.
        is_temp = options.pop('temporary', False)
        ctx = self._create_context()
        ctx.literal('CREATE TEMPORARY TABLE ' if is_temp else 'CREATE TABLE ')
        if safe:
            ctx.literal('IF NOT EXISTS ')
        ctx.sql(self.model).literal(' ')

        columns = []
        constraints = []
        meta = self.model._meta
        if meta.composite_key:
            pk_columns = [meta.fields[field_name].column
                          for field_name in meta.primary_key.field_names]
            constraints.append(NodeList((SQL('PRIMARY KEY'),
                                         EnclosedNodeList(pk_columns))))

        for field in meta.sorted_fields:
            columns.append(field.ddl(ctx))
            # Deferred foreign keys get no inline constraint here.
            if isinstance(field, ForeignKeyField) and not field.deferred:
                constraints.append(field.foreign_key_constraint())

        if meta.constraints:
            constraints.extend(meta.constraints)

        constraints.extend(self._create_table_option_sql(options))
        ctx.sql(EnclosedNodeList(columns + constraints))

        if meta.table_settings is not None:
            table_settings = ensure_tuple(meta.table_settings)
            for setting in table_settings:
                if not isinstance(setting, basestring):
                    raise ValueError('table_settings must be strings')
                ctx.literal(' ').literal(setting)

        if meta.without_rowid:
            ctx.literal(' WITHOUT ROWID')
        return ctx

    def _create_table_option_sql(self, options):
        """Render table options as sorted ``key=value`` nodes.

        ``Meta.options`` and the call-time ``options`` are combined with
        ``merge_dict`` (presumably call-time values win -- confirm). Model
        values render as their table; other non-Node values as ``str()``.
        """
        accum = []
        options = merge_dict(self.model._meta.options or {}, options)
        if not options:
            return accum

        for key, value in sorted(options.items()):
            if not isinstance(value, Node):
                if is_model(value):
                    value = value._meta.table
                else:
                    value = SQL(str(value))
            accum.append(NodeList((SQL(key), value), glue='='))
        return accum

    def create_table(self, safe=True, **options):
        self.database.execute(self._create_table(safe=safe, **options))

    def _create_table_as(self, table_name, query, safe=True, **meta):
        """Build ``CREATE [TEMPORARY] TABLE [IF NOT EXISTS] <name> AS <query>``.

        Only the ``temporary`` key of ``meta`` is consulted.
        """
        ctx = (self._create_context()
               .literal('CREATE TEMPORARY TABLE '
                        if meta.get('temporary') else 'CREATE TABLE '))
        if safe:
            ctx.literal('IF NOT EXISTS ')
        return (ctx
                .sql(Entity(table_name))
                .literal(' AS ')
                .sql(query))

    def create_table_as(self, table_name, query, safe=True, **meta):
        ctx = self._create_table_as(table_name, query, safe=safe, **meta)
        self.database.execute(ctx)

    def _drop_table(self, safe=True, **options):
        """Build ``DROP TABLE [IF EXISTS] <table> [CASCADE | RESTRICT]``.

        ``cascade`` takes precedence over ``restrict`` when both are set.
        """
        ctx = (self._create_context()
               .literal('DROP TABLE IF EXISTS ' if safe else 'DROP TABLE ')
               .sql(self.model))
        if options.get('cascade'):
            ctx = ctx.literal(' CASCADE')
        elif options.get('restrict'):
            ctx = ctx.literal(' RESTRICT')
        return ctx

    def drop_table(self, safe=True, **options):
        self.database.execute(self._drop_table(safe=safe, **options))

    def _truncate_table(self, restart_identity=False, cascade=False):
        """Build ``TRUNCATE TABLE`` or, when the database's
        ``truncate_table`` flag is falsy, ``DELETE FROM`` (in which case
        ``restart_identity`` and ``cascade`` are ignored).
        """
        db = self.database
        if not db.truncate_table:
            return (self._create_context()
                    .literal('DELETE FROM ').sql(self.model))

        ctx = self._create_context().literal('TRUNCATE TABLE ').sql(self.model)
        if restart_identity:
            ctx = ctx.literal(' RESTART IDENTITY')
        if cascade:
            ctx = ctx.literal(' CASCADE')
        return ctx

    def truncate_table(self, restart_identity=False, cascade=False):
        self.database.execute(self._truncate_table(restart_identity, cascade))

    def _create_indexes(self, safe=True):
        # One context per index reported by Metadata.fields_to_index().
        return [self._create_index(index, safe)
                for index in self.model._meta.fields_to_index()]

    def _create_index(self, index, safe=True):
        """Build the CREATE INDEX context for ``index``.

        For Index objects, IF NOT EXISTS is forced off when the database lacks
        ``safe_create_index``; otherwise the index is adjusted to match
        ``safe``. Other nodes are rendered unchanged.
        """
        if isinstance(index, Index):
            if not self.database.safe_create_index:
                index = index.safe(False)
            elif index._safe != safe:
                index = index.safe(safe)
        return self._create_context().sql(index)

    def create_indexes(self, safe=True):
        for query in self._create_indexes(safe=safe):
            self.database.execute(query)

    def _drop_indexes(self, safe=True):
        # Only Index objects can be dropped by name; other nodes are skipped.
        return [self._drop_index(index, safe)
                for index in self.model._meta.fields_to_index()
                if isinstance(index, Index)]

    def _drop_index(self, index, safe):
        """Build ``DROP INDEX [IF EXISTS] <name>``.

        IF EXISTS is only used when requested and supported
        (``safe_drop_index``); the name is schema-qualified when the index's
        table is a Table with a schema.
        """
        statement = 'DROP INDEX '
        if safe and self.database.safe_drop_index:
            statement += 'IF EXISTS '
        if isinstance(index._table, Table) and index._table._schema:
            index_name = Entity(index._table._schema, index._name)
        else:
            index_name = Entity(index._name)
        return (self
                ._create_context()
                .literal(statement)
                .sql(index_name))

    def drop_indexes(self, safe=True):
        for query in self._drop_indexes(safe=safe):
            self.database.execute(query)

    def _check_sequences(self, field):
        # Raises ValueError unless the field names a sequence and the
        # database supports sequences.
        if not field.sequence or not self.database.sequences:
            raise ValueError('Sequences are either not supported, or are not '
                             'defined for "%s".' % field.name)

    def _sequence_for_field(self, field):
        # Sequence name, qualified with the model's schema when it has one.
        if field.model._meta.schema:
            return Entity(field.model._meta.schema, field.sequence)
        else:
            return Entity(field.sequence)

    def _create_sequence(self, field):
        """Build ``CREATE SEQUENCE``, or return None if it already exists.

        NOTE(review): the existence check passes only the bare sequence name,
        not the schema -- confirm sequence_exists() handles schemas.
        """
        self._check_sequences(field)
        if not self.database.sequence_exists(field.sequence):
            return (self
                    ._create_context()
                    .literal('CREATE SEQUENCE ')
                    .sql(self._sequence_for_field(field)))

    def create_sequence(self, field):
        seq_ctx = self._create_sequence(field)
        if seq_ctx is not None:
            self.database.execute(seq_ctx)

    def _drop_sequence(self, field):
        """Build ``DROP SEQUENCE``, or return None if it does not exist."""
        self._check_sequences(field)
        if self.database.sequence_exists(field.sequence):
            return (self
                    ._create_context()
                    .literal('DROP SEQUENCE ')
                    .sql(self._sequence_for_field(field)))

    def drop_sequence(self, field):
        seq_ctx = self._drop_sequence(field)
        if seq_ctx is not None:
            self.database.execute(seq_ctx)

    def _create_foreign_key(self, field):
        """Build ``ALTER TABLE ... ADD CONSTRAINT <name> FOREIGN KEY ...``.

        The constraint is named ``fk_<table>_<column>_refs_<rel table>``,
        passed through ``_truncate_constraint_name`` (presumably to respect
        identifier length limits -- confirm).
        """
        name = 'fk_%s_%s_refs_%s' % (field.model._meta.table_name,
                                     field.column_name,
                                     field.rel_model._meta.table_name)
        return (self
                ._create_context()
                .literal('ALTER TABLE ')
                .sql(field.model)
                .literal(' ADD CONSTRAINT ')
                .sql(Entity(_truncate_constraint_name(name)))
                .literal(' ')
                .sql(field.foreign_key_constraint()))

    def create_foreign_key(self, field):
        self.database.execute(self._create_foreign_key(field))

    def create_sequences(self):
        # No-op unless the database supports sequences.
        if self.database.sequences:
            for field in self.model._meta.sorted_fields:
                if field.sequence:
                    self.create_sequence(field)

    def create_all(self, safe=True, **table_options):
        # Sequences first (columns may depend on them), then table, indexes.
        self.create_sequences()
        self.create_table(safe, **table_options)
        self.create_indexes(safe=safe)

    def drop_sequences(self):
        if self.database.sequences:
            for field in self.model._meta.sorted_fields:
                if field.sequence:
                    self.drop_sequence(field)

    def drop_all(self, safe=True, drop_sequences=True, **options):
        # Table first, then (optionally) its sequences.
        self.drop_table(safe, **options)
        if drop_sequences:
            self.drop_sequences()
5796
5797
class Metadata(object):
    """Per-model bookkeeping: fields, columns, defaults, relations and table.

    Model classes expose this object as ``_meta``. It tracks the model's
    fields (by name and by column), default values, foreign-key graph,
    many-to-many fields, indexes and the lazily built ``Table``.
    """
    def __init__(self, model, database=None, table_name=None, indexes=None,
                 primary_key=None, constraints=None, schema=None,
                 only_save_dirty=False, depends_on=None, options=None,
                 db_table=None, table_function=None, table_settings=None,
                 without_rowid=False, temporary=False, legacy_table_names=True,
                 **kwargs):
        """
        Args:
            model: the model class described by this metadata.
            table_name: explicit table name; otherwise derived via
                ``table_function(model)`` or :meth:`make_table_name`.
            db_table: deprecated alias for ``table_name``.
            **kwargs: any other options, stored as attributes and recorded
                in ``_additional_keys``.
        """
        if db_table is not None:
            __deprecated__('"db_table" has been deprecated in favor of '
                           '"table_name" for Models.')
            table_name = db_table
        self.model = model
        self.database = database

        # Field lookups by attribute name, by column name, and by either.
        self.fields = {}
        self.columns = {}
        self.combined = {}

        # Non-meta fields ordered by _sort_key, plus derived snapshots.
        self._sorted_field_list = _SortedFieldList()
        self.sorted_fields = []
        self.sorted_field_names = []

        # Default-value caches kept in sync by add_field()/remove_field().
        self.defaults = {}
        self._default_by_name = {}
        self._default_dict = {}
        self._default_callables = {}
        self._default_callable_list = []

        self.name = model.__name__.lower()
        self.table_function = table_function
        self.legacy_table_names = legacy_table_names
        if not table_name:
            table_name = (self.table_function(model)
                          if self.table_function
                          else self.make_table_name())
        self.table_name = table_name
        self._table = None

        self.indexes = list(indexes) if indexes else []
        self.constraints = constraints
        self._schema = schema
        self.primary_key = primary_key
        self.composite_key = self.auto_increment = None
        self.only_save_dirty = only_save_dirty
        self.depends_on = depends_on
        self.table_settings = table_settings
        self.without_rowid = without_rowid
        self.temporary = temporary

        # Foreign-key graph: ``refs`` maps FKs on this model to their target;
        # ``backrefs`` maps FKs on other models that point here to their
        # owner. The ``model_*`` variants group the FKs by model.
        self.refs = {}
        self.backrefs = {}
        self.model_refs = collections.defaultdict(list)
        self.model_backrefs = collections.defaultdict(list)
        self.manytomany = {}

        self.options = options or {}
        for key, value in kwargs.items():
            setattr(self, key, value)
        self._additional_keys = set(kwargs.keys())

        # Allow objects to register hooks that are called if the model is bound
        # to a different database. For example, BlobField uses a different
        # Python data-type depending on the db driver / python version. When
        # the database changes, we need to update any BlobField so they can use
        # the appropriate data-type.
        self._db_hooks = []

    def make_table_name(self):
        """Default table name: the lower-cased class name with runs of
        non-word characters replaced by ``_`` (legacy), or the snake-cased
        class name."""
        if self.legacy_table_names:
            return re.sub(r'[^\w]+', '_', self.name)
        return make_snake_case(self.model.__name__)

    def model_graph(self, refs=True, backrefs=True, depth_first=True):
        """Walk the foreign-key graph reachable from this model.

        Returns a list of ``(fk, model, is_backref)`` triples, starting with
        ``(None, self.model, None)``. Each model's edges are expanded once,
        but edges leading to already-visited models are still recorded.

        Raises:
            ValueError: both ``refs`` and ``backrefs`` are false.
        """
        if not refs and not backrefs:
            raise ValueError('One of `refs` or `backrefs` must be True.')

        accum = [(None, self.model, None)]
        seen = set()
        queue = collections.deque((self,))
        # Stack (LIFO) for depth-first, queue (FIFO) for breadth-first.
        method = queue.pop if depth_first else queue.popleft

        while queue:
            curr = method()
            if curr in seen: continue
            seen.add(curr)

            if refs:
                for fk, model in curr.refs.items():
                    accum.append((fk, model, False))
                    queue.append(model._meta)
            if backrefs:
                for fk, model in curr.backrefs.items():
                    accum.append((fk, model, True))
                    queue.append(model._meta)

        return accum

    def add_ref(self, field):
        """Record FK ``field`` here and as a backref on its target model."""
        rel = field.rel_model
        self.refs[field] = rel
        self.model_refs[rel].append(field)
        rel._meta.backrefs[field] = self.model
        rel._meta.model_backrefs[self.model].append(field)

    def remove_ref(self, field):
        """Undo :meth:`add_ref` for ``field``."""
        rel = field.rel_model
        del self.refs[field]
        self.model_refs[rel].remove(field)
        del rel._meta.backrefs[field]
        rel._meta.model_backrefs[self.model].remove(field)

    def add_manytomany(self, field):
        self.manytomany[field.name] = field

    def remove_manytomany(self, field):
        del self.manytomany[field.name]

    @property
    def table(self):
        """Lazily built ``Table`` for this model; reset by ``del meta.table``."""
        if self._table is None:
            self._table = Table(
                self.table_name,
                [field.column_name for field in self.sorted_fields],
                schema=self.schema,
                _model=self.model,
                _database=self.database)
        return self._table

    @table.setter
    def table(self, value):
        raise AttributeError('Cannot set the "table".')

    @table.deleter
    def table(self):
        self._table = None

    @property
    def schema(self):
        return self._schema

    @schema.setter
    def schema(self, value):
        # The cached Table embeds the schema, so invalidate it.
        self._schema = value
        del self.table

    @property
    def entity(self):
        """Table name as an Entity, schema-qualified when a schema is set."""
        if self._schema:
            return Entity(self._schema, self.table_name)
        else:
            return Entity(self.table_name)

    def _update_sorted_fields(self):
        # Refresh the plain-list snapshots from the sorted field list.
        self.sorted_fields = list(self._sorted_field_list)
        self.sorted_field_names = [f.name for f in self.sorted_fields]

    def get_rel_for_model(self, model):
        """Return ``(fks to model, fks from model)``; unwraps ModelAlias."""
        if isinstance(model, ModelAlias):
            model = model.model
        forwardrefs = self.model_refs.get(model, [])
        backrefs = self.model_backrefs.get(model, [])
        return (forwardrefs, backrefs)

    def add_field(self, field_name, field, set_attribute=True):
        """Bind ``field`` to the model as ``field_name`` and index it.

        An existing field or many-to-many field of the same name is removed
        first. MetaFields are only bound; regular fields are also registered
        in the lookup dicts, the sorted field list and the default caches.
        Foreign keys and named many-to-many fields are additionally recorded
        in the relation registries.
        """
        if field_name in self.fields:
            self.remove_field(field_name)
        elif field_name in self.manytomany:
            self.remove_manytomany(self.manytomany[field_name])

        if not isinstance(field, MetaField):
            # The cached Table lists columns, so invalidate it.
            del self.table
            field.bind(self.model, field_name, set_attribute)
            self.fields[field.name] = field
            self.columns[field.column_name] = field
            self.combined[field.name] = field
            self.combined[field.column_name] = field

            self._sorted_field_list.insert(field)
            self._update_sorted_fields()

            if field.default is not None:
                # This optimization helps speed up model instance construction.
                self.defaults[field] = field.default
                if callable_(field.default):
                    self._default_callables[field] = field.default
                    self._default_callable_list.append((field.name,
                                                        field.default))
                else:
                    self._default_dict[field] = field.default
                    self._default_by_name[field.name] = field.default
        else:
            field.bind(self.model, field_name, set_attribute)

        if isinstance(field, ForeignKeyField):
            self.add_ref(field)
        elif isinstance(field, ManyToManyField) and field.name:
            self.add_manytomany(field)

    def remove_field(self, field_name):
        """Undo :meth:`add_field` for a regular field; no-op if unknown."""
        if field_name not in self.fields:
            return

        del self.table
        original = self.fields.pop(field_name)
        del self.columns[original.column_name]
        del self.combined[field_name]
        try:
            del self.combined[original.column_name]
        except KeyError:
            pass
        self._sorted_field_list.remove(original)
        self._update_sorted_fields()

        if original.default is not None:
            del self.defaults[original]
            # Test for presence, not truthiness of the popped callable: a
            # falsy callable default must still be dropped from
            # _default_callable_list (stored defaults are never None).
            if self._default_callables.pop(original, None) is not None:
                for i, (name, _) in enumerate(self._default_callable_list):
                    if name == field_name:
                        self._default_callable_list.pop(i)
                        break
            else:
                self._default_dict.pop(original, None)
                self._default_by_name.pop(original.name, None)

        if isinstance(original, ForeignKeyField):
            self.remove_ref(original)

    def set_primary_key(self, name, field):
        """Add ``field`` as the primary key and derive key-related flags."""
        self.composite_key = isinstance(field, CompositeKey)
        self.add_field(name, field)
        self.primary_key = field
        self.auto_increment = (
            field.auto_increment or
            bool(field.sequence))

    def get_primary_keys(self):
        """Tuple of primary-key fields (components for a composite key);
        empty when ``primary_key`` is False."""
        if self.composite_key:
            return tuple([self.fields[field_name]
                          for field_name in self.primary_key.field_names])
        else:
            return (self.primary_key,) if self.primary_key is not False else ()

    def get_default_dict(self):
        """Map field name -> default value, calling callable defaults."""
        dd = self._default_by_name.copy()
        for field_name, default in self._default_callable_list:
            dd[field_name] = default()
        return dd

    def fields_to_index(self):
        """Collect the indexes to create for this model.

        Includes a ModelIndex for every non-primary-key field flagged
        ``index`` or ``unique``, plus each entry of ``Meta.indexes``: Node
        objects as-is, ``(parts, unique)`` pairs as a ModelIndex.

        Raises:
            ValueError: an index part is neither a field name nor a Node.
        """
        indexes = []
        for f in self.sorted_fields:
            if f.primary_key:
                continue
            if f.index or f.unique:
                indexes.append(ModelIndex(self.model, (f,), unique=f.unique,
                                          using=f.index_type))

        for index_obj in self.indexes:
            if isinstance(index_obj, Node):
                indexes.append(index_obj)
            elif isinstance(index_obj, (list, tuple)):
                index_parts, unique = index_obj
                fields = []
                for part in index_parts:
                    if isinstance(part, basestring):
                        fields.append(self.combined[part])
                    elif isinstance(part, Node):
                        fields.append(part)
                    else:
                        raise ValueError('Expected either a field name or a '
                                         'subclass of Node. Got: %s' % part)
                indexes.append(ModelIndex(self.model, fields, unique=unique))

        return indexes

    def set_database(self, database):
        """Rebind the model (and its schema manager) to ``database``."""
        self.database = database
        self.model._schema._database = database
        del self.table

        # Apply any hooks that have been registered.
        for hook in self._db_hooks:
            hook(database)

    def set_table_name(self, table_name):
        self.table_name = table_name
        del self.table
6085
6086
class SubclassAwareMetadata(Metadata):
    """
    Metadata that remembers every model class it is created for.

    Each new instance appends its model to the class-level ``models`` list.
    That list is shared by all instances, so ``map_models`` can apply a
    function to every registered model.
    """
    # Intentionally a class attribute: one registry shared by all instances.
    models = []

    def __init__(self, model, *args, **kwargs):
        super(SubclassAwareMetadata, self).__init__(model, *args, **kwargs)
        registry = self.models
        registry.append(model)

    def map_models(self, fn):
        """Call ``fn(model)`` for each registered model, in registration order."""
        for registered in self.models:
            fn(registered)
6097
6098
class DoesNotExist(Exception):
    """Common base of the per-model ``<Model>DoesNotExist`` exceptions."""
6100
6101
class ModelBase(type):
    """
    Metaclass for model classes.

    When a model class is created it:
      * gathers options from the inner ``Meta`` class, plus inheritable
        options from any bases that have ``_meta``;
      * deep-copies inherited non-primary-key fields into the new class;
      * builds the ``_meta`` (metadata) and ``_schema`` (schema manager)
        objects;
      * resolves the primary key (explicit field, ``Meta.primary_key``,
        inherited, or an implicit ``AutoField`` named ``id``);
      * adds a default ``__repr__`` and a per-model ``DoesNotExist`` class.

    The class-level dunder methods below let a model class be used like a
    container of its rows: iteration, ``len()``, ``in``, and item
    get/set/delete by primary key.
    """
    # Meta options a model inherits from its model bases unless its own Meta
    # sets them.
    inheritable = set(['constraints', 'database', 'indexes', 'primary_key',
                       'options', 'schema', 'table_function', 'temporary',
                       'only_save_dirty', 'legacy_table_names',
                       'table_settings'])

    def __new__(cls, name, bases, attrs):
        # Classes named MODEL_BASE, or whose first base has that name, are
        # created as plain classes with no metadata processing.
        if name == MODEL_BASE or bases[0].__name__ == MODEL_BASE:
            return super(ModelBase, cls).__new__(cls, name, bases, attrs)

        # Public (non-underscore) attributes of the inner Meta class become
        # options. Meta is removed from the class namespace.
        meta_options = {}
        meta = attrs.pop('Meta', None)
        if meta:
            for k, v in meta.__dict__.items():
                if not k.startswith('_'):
                    meta_options[k] = v

        # A primary key may be declared on Meta (a CompositeKey, or False
        # for "no primary key"; see the checks further below).
        pk = getattr(meta, 'primary_key', None)
        pk_name = parent_pk = None

        # Inherit any field descriptors by deep copying the underlying field
        # into the attrs of the new model, additionally see if the bases define
        # inheritable model options and swipe them.
        for b in bases:
            if not hasattr(b, '_meta'):
                continue

            base_meta = b._meta
            # The first model base supplies the inherited primary key. It is
            # deep-copied so the subclass gets its own instance.
            if parent_pk is None:
                parent_pk = deepcopy(base_meta.primary_key)
            all_inheritable = cls.inheritable | base_meta._additional_keys
            for k in base_meta.__dict__:
                if k in all_inheritable and k not in meta_options:
                    meta_options[k] = base_meta.__dict__[k]
            meta_options.setdefault('schema', base_meta.schema)

            # Attributes defined on the new class take precedence. Inherited
            # primary-key fields are handled through parent_pk instead.
            for (k, v) in b.__dict__.items():
                if k in attrs: continue

                if isinstance(v, FieldAccessor) and not v.field.primary_key:
                    attrs[k] = deepcopy(v.field)

        # Options selecting/configuring the metadata and schema-manager classes.
        sopts = meta_options.pop('schema_options', None) or {}
        Meta = meta_options.get('model_metadata_class', Metadata)
        Schema = meta_options.get('schema_manager_class', SchemaManager)

        # Construct the new class.
        cls = super(ModelBase, cls).__new__(cls, name, bases, attrs)
        cls.__data__ = cls.__rel__ = None

        cls._meta = Meta(cls, **meta_options)
        cls._schema = Schema(cls, **sopts)

        # Split the class's Field attributes into the primary key and the
        # rest. A second primary key (including one declared on Meta) is an
        # error.
        fields = []
        for key, value in cls.__dict__.items():
            if isinstance(value, Field):
                if value.primary_key and pk:
                    raise ValueError('over-determined primary key %s.' % name)
                elif value.primary_key:
                    pk, pk_name = value, key
                else:
                    fields.append((key, value))

        # No explicit primary key: reuse the parent's, or fall back to an
        # AutoField named "id". An inherited pk of False means "no primary
        # key".
        if pk is None:
            if parent_pk is not False:
                pk, pk_name = ((parent_pk, parent_pk.name)
                               if parent_pk is not None else
                               (AutoField(), 'id'))
            else:
                pk = False
        elif isinstance(pk, CompositeKey):
            pk_name = '__composite_key__'
            cls._meta.composite_key = True

        if pk is not False:
            cls._meta.set_primary_key(pk_name, pk)

        # NOTE: this loop rebinds ``name`` (the class-name argument). The
        # class name is not read again below; later code uses cls.__name__.
        for name, field in fields:
            cls._meta.add_field(name, field)

        # Create a repr and error class before finalizing.
        if hasattr(cls, '__str__') and '__repr__' not in attrs:
            setattr(cls, '__repr__', lambda self: '<%s: %s>' % (
                cls.__name__, self.__str__()))

        # Each model gets its own DoesNotExist subclass; for example, a model
        # named User gets UserDoesNotExist.
        exc_name = '%sDoesNotExist' % cls.__name__
        exc_attrs = {'__module__': cls.__module__}
        exception_class = type(exc_name, (DoesNotExist,), exc_attrs)
        cls.DoesNotExist = exception_class

        # Call validation hook, allowing additional model validation.
        cls.validate_model()
        # Give DeferredForeignKey a chance to resolve references waiting on
        # this newly created class.
        DeferredForeignKey.resolve(cls)
        return cls

    def __repr__(self):
        return '<Model: %s>' % self.__name__

    def __iter__(self):
        # Iterating a model class iterates a full select() of its rows.
        return iter(self.select())

    def __getitem__(self, key):
        # Model[pk] -> instance (raises the model's DoesNotExist if missing).
        return self.get_by_id(key)

    def __setitem__(self, key, value):
        # Model[pk] = {...} updates the row; Model[None] = {...} inserts.
        self.set_by_id(key, value)

    def __delitem__(self, key):
        self.delete_by_id(key)

    def __contains__(self, key):
        # ``pk in Model`` issues a lookup query by primary key.
        try:
            self.get_by_id(key)
        except self.DoesNotExist:
            return False
        else:
            return True

    def __len__(self):
        # Row count of the table via a COUNT query.
        return self.select().count()
    # Always truthy; otherwise __len__ would make a model class with an
    # empty table evaluate as False.
    def __bool__(self): return True
    __nonzero__ = __bool__  # Python 2.

    def __sql__(self, ctx):
        # A model class renders as its table.
        return ctx.sql(self._meta.table)
6227
6228
class _BoundModelsContext(_callable_context_manager):
    """
    Temporarily bind a group of models to ``database``.

    On entry each model's current database is remembered and the model is
    bound to ``database``. On exit every model is re-bound to the database
    it had before. The whole group is passed as ``_exclude`` on every bind,
    so binding one model never rebinds another member of the group through
    ref/backref traversal.
    """
    def __init__(self, models, database, bind_refs, bind_backrefs):
        self.models = models
        self.database = database
        self.bind_refs = bind_refs
        self.bind_backrefs = bind_backrefs

    def __enter__(self):
        saved = self._orig_database = []
        for model in self.models:
            saved.append(model._meta.database)
            # A fresh exclusion set per call: bind() adds to it.
            model.bind(self.database, self.bind_refs, self.bind_backrefs,
                       _exclude=set(self.models))
        return self.models

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore previous databases. Returning None lets exceptions propagate.
        for model, previous_db in zip(self.models, self._orig_database):
            model.bind(previous_db, self.bind_refs, self.bind_backrefs,
                       _exclude=set(self.models))
6248
6249
class Model(with_metaclass(ModelBase, Node)):
    """
    Base class for user-defined models.

    Instance state:
      * ``__data__`` holds column values keyed by field name;
      * ``_dirty`` is the set of field names modified since the last save;
      * ``__rel__`` caches related objects.

    Class methods build queries (select/insert/update/delete, ...) against
    the table described by ``cls._meta``.
    """
    def __init__(self, *args, **kwargs):
        # Positional arguments are accepted but ignored. The private
        # ``__no_default__`` flag skips populating field defaults.
        # NOTE(review): presumably set when hydrating rows from a query;
        # confirm against callers.
        if kwargs.pop('__no_default__', None):
            self.__data__ = {}
        else:
            self.__data__ = self._meta.get_default_dict()
        # Every field that received a default starts out dirty.
        self._dirty = set(self.__data__)
        self.__rel__ = {}

        # Assign keyword values through setattr so any descriptors defined
        # on the class handle them.
        for k in kwargs:
            setattr(self, k, kwargs[k])

    def __str__(self):
        """Primary-key value as a string, or 'n/a' when there is no primary key."""
        return str(self._pk) if self._meta.primary_key is not False else 'n/a'

    @classmethod
    def validate_model(cls):
        """Hook called by ModelBase after the class is built; no-op by default."""
        pass

    @classmethod
    def alias(cls, alias=None):
        """Return a ModelAlias for this model, optionally named ``alias``."""
        return ModelAlias(cls, alias)

    @classmethod
    def select(cls, *fields):
        """
        Return a ModelSelect over ``fields``.

        When no fields are given, every field is selected and the query is
        flagged ``is_default``.
        """
        is_default = not fields
        if not fields:
            fields = cls._meta.sorted_fields
        return ModelSelect(cls, fields, is_default=is_default)

    @classmethod
    def _normalize_data(cls, data, kwargs):
        """
        Merge ``data`` and ``kwargs`` into a dict keyed by field/expression.

        A truthy non-dict ``data`` is returned unchanged; combining it with
        kwargs raises ValueError. Dict keys may be Field instances, names
        known to ``_meta.combined``, or other Node expressions; any other
        key raises ValueError. Keyword names not found in ``combined`` are
        resolved with getattr() on the class.
        """
        normalized = {}
        if data:
            if not isinstance(data, dict):
                if kwargs:
                    raise ValueError('Data cannot be mixed with keyword '
                                     'arguments: %s' % data)
                return data
            for key in data:
                try:
                    field = (key if isinstance(key, Field)
                             else cls._meta.combined[key])
                except KeyError:
                    if not isinstance(key, Node):
                        raise ValueError('Unrecognized field name: "%s" in %s.'
                                         % (key, data))
                    field = key
                normalized[field] = data[key]
        if kwargs:
            for key in kwargs:
                try:
                    normalized[cls._meta.combined[key]] = kwargs[key]
                except KeyError:
                    normalized[getattr(cls, key)] = kwargs[key]
        return normalized

    @classmethod
    def update(cls, __data=None, **update):
        """Build an UPDATE query from a dict and/or keyword field values."""
        return ModelUpdate(cls, cls._normalize_data(__data, update))

    @classmethod
    def insert(cls, __data=None, **insert):
        """Build a single-row INSERT from a dict and/or keyword field values."""
        return ModelInsert(cls, cls._normalize_data(__data, insert))

    @classmethod
    def insert_many(cls, rows, fields=None):
        """Build a multi-row INSERT of ``rows``; ``fields`` gives the column list."""
        return ModelInsert(cls, insert=rows, columns=fields)

    @classmethod
    def insert_from(cls, query, fields):
        """
        Build an INSERT ... SELECT from ``query``.

        String entries in ``fields`` are resolved to class attributes.
        """
        columns = [getattr(cls, field) if isinstance(field, basestring)
                   else field for field in fields]
        return ModelInsert(cls, insert=query, columns=columns)

    @classmethod
    def replace(cls, __data=None, **insert):
        """Like insert(), with on_conflict('REPLACE')."""
        return cls.insert(__data, **insert).on_conflict('REPLACE')

    @classmethod
    def replace_many(cls, rows, fields=None):
        """Like insert_many(), with on_conflict('REPLACE')."""
        return (cls
                .insert_many(rows=rows, fields=fields)
                .on_conflict('REPLACE'))

    @classmethod
    def raw(cls, sql, *params):
        """Wrap a raw SQL string and its parameters in a ModelRaw query."""
        return ModelRaw(cls, sql, params)

    @classmethod
    def delete(cls):
        """Build a DELETE query for this model's table."""
        return ModelDelete(cls)

    @classmethod
    def create(cls, **query):
        """Instantiate from keyword values, INSERT it, and return the instance."""
        inst = cls(**query)
        inst.save(force_insert=True)
        return inst

    @classmethod
    def bulk_create(cls, model_list, batch_size=None):
        """
        INSERT unsaved instances with one multi-row query per batch.

        ``batch_size`` splits ``model_list`` with chunked(); None means a
        single batch. An auto-increment primary key is omitted from the
        column list. When the database reports ``returning_clause`` and the
        model has a primary key, the returned key values are written back
        onto the instances. Returns None.
        """
        if batch_size is not None:
            batches = chunked(model_list, batch_size)
        else:
            batches = [model_list]

        field_names = list(cls._meta.sorted_field_names)
        if cls._meta.auto_increment:
            pk_name = cls._meta.primary_key.name
            field_names.remove(pk_name)

        if cls._meta.database.returning_clause and \
           cls._meta.primary_key is not False:
            pk_fields = cls._meta.get_primary_keys()
        else:
            pk_fields = None

        fields = [cls._meta.fields[field_name] for field_name in field_names]
        # Foreign-keys are read through object_id_name (the raw id attribute)
        # rather than the field name.
        attrs = []
        for field in fields:
            if isinstance(field, ForeignKeyField):
                attrs.append(field.object_id_name)
            else:
                attrs.append(field.name)

        for batch in batches:
            accum = ([getattr(model, f) for f in attrs]
                     for model in batch)
            res = cls.insert_many(accum, fields=fields).execute()
            if pk_fields and res is not None:
                for row, model in zip(res, batch):
                    for (pk_field, obj_id) in zip(pk_fields, row):
                        setattr(model, pk_field.name, obj_id)

    @classmethod
    def bulk_update(cls, model_list, fields, batch_size=None):
        """
        UPDATE ``fields`` for many instances with one query per batch.

        Each field is set through a CASE expression keyed on the primary key.
        ``fields`` may mix names and Field instances. Returns the sum of the
        values returned by each UPDATE's execute(). Raises ValueError for
        models with a CompositeKey.
        """
        if isinstance(cls._meta.primary_key, CompositeKey):
            raise ValueError('bulk_update() is not supported for models with '
                             'a composite primary key.')

        # First normalize list of fields so all are field instances.
        fields = [cls._meta.fields[f] if isinstance(f, basestring) else f
                  for f in fields]
        # Now collect list of attribute names to use for values.
        attrs = [field.object_id_name if isinstance(field, ForeignKeyField)
                 else field.name for field in fields]

        if batch_size is not None:
            batches = chunked(model_list, batch_size)
        else:
            batches = [model_list]

        n = 0
        pk = cls._meta.primary_key

        for batch in batches:
            id_list = [model._pk for model in batch]
            update = {}
            for field, attr in zip(fields, attrs):
                accum = []
                for model in batch:
                    value = getattr(model, attr)
                    # Node values (SQL expressions) are passed through as-is.
                    if not isinstance(value, Node):
                        value = field.to_value(value)
                    accum.append((pk.to_value(model._pk), value))
                case = Case(pk, accum)
                update[field] = case

            n += (cls.update(update)
                  .where(cls._meta.primary_key.in_(id_list))
                  .execute())
        return n

    @classmethod
    def noop(cls):
        """Return a NoopModelSelect with an empty selection."""
        return NoopModelSelect(cls, ())

    @classmethod
    def get(cls, *query, **filters):
        """
        Return the first row matching the expressions and/or filters.

        A single int positional argument is treated as a primary-key value.
        Raises the model's DoesNotExist when nothing matches.
        """
        sq = cls.select()
        if query:
            # Handle simple lookup using just the primary key.
            if len(query) == 1 and isinstance(query[0], int):
                sq = sq.where(cls._meta.primary_key == query[0])
            else:
                sq = sq.where(*query)
        if filters:
            sq = sq.filter(**filters)
        return sq.get()

    @classmethod
    def get_or_none(cls, *query, **filters):
        """Like get(), but return None instead of raising DoesNotExist."""
        try:
            return cls.get(*query, **filters)
        except DoesNotExist:
            pass

    @classmethod
    def get_by_id(cls, pk):
        """Return the row whose primary key equals ``pk``."""
        return cls.get(cls._meta.primary_key == pk)

    @classmethod
    def set_by_id(cls, key, value):
        """
        INSERT ``value`` when ``key`` is None; otherwise UPDATE the row
        with that primary key. Returns the query's execute() result.
        """
        if key is None:
            return cls.insert(value).execute()
        else:
            return (cls.update(value)
                    .where(cls._meta.primary_key == key).execute())

    @classmethod
    def delete_by_id(cls, pk):
        """DELETE the row with primary key ``pk``; returns execute()'s result."""
        return cls.delete().where(cls._meta.primary_key == pk).execute()

    @classmethod
    def get_or_create(cls, **kwargs):
        """
        Return ``(instance, created)`` for the row matching ``kwargs``.

        ``defaults`` (a dict) supplies extra values used only on creation.
        The creation runs inside an atomic block. If it fails with
        IntegrityError (e.g. a concurrent insert), the lookup is retried
        once; the original error is re-raised if the row is still missing.
        """
        defaults = kwargs.pop('defaults', {})
        query = cls.select()
        for field, value in kwargs.items():
            query = query.where(getattr(cls, field) == value)

        try:
            return query.get(), False
        except cls.DoesNotExist:
            try:
                if defaults:
                    kwargs.update(defaults)
                with cls._meta.database.atomic():
                    return cls.create(**kwargs), True
            except IntegrityError as exc:
                try:
                    return query.get(), False
                except cls.DoesNotExist:
                    raise exc

    @classmethod
    def filter(cls, *dq_nodes, **filters):
        """Shortcut for ``cls.select().filter(...)``."""
        return cls.select().filter(*dq_nodes, **filters)

    def get_id(self):
        # Using getattr(self, pk-name) could accidentally trigger a query if
        # the primary-key is a foreign-key. So we use the safe_name attribute,
        # which defaults to the field-name, but will be the object_id_name for
        # foreign-key fields.
        if self._meta.primary_key is not False:
            return getattr(self, self._meta.primary_key.safe_name)

    # ``_pk`` reads via get_id() and writes through the primary-key field name.
    _pk = property(get_id)

    @_pk.setter
    def _pk(self, value):
        setattr(self, self._meta.primary_key.name, value)

    def _pk_expr(self):
        """Expression matching this instance's row by primary key."""
        return self._meta.primary_key == self._pk

    def _prune_fields(self, field_dict, only):
        """
        Restrict ``field_dict`` to the fields in ``only``.

        ``only`` may contain Field instances or names. The result is keyed
        by field name.
        """
        new_data = {}
        for field in only:
            if isinstance(field, basestring):
                field = self._meta.combined[field]
            if field.name in field_dict:
                new_data[field.name] = field_dict[field.name]
        return new_data

    def _populate_unsaved_relations(self, field_dict):
        """
        Refresh foreign-key ids in ``field_dict`` from cached related objects.

        This applies when the stored id is None but a related object is
        cached in ``__rel__``. Re-assigning the related object copies its
        current id into ``__data__``.
        """
        for foreign_key_field in self._meta.refs:
            foreign_key = foreign_key_field.name
            conditions = (
                foreign_key in field_dict and
                field_dict[foreign_key] is None and
                self.__rel__.get(foreign_key) is not None)
            if conditions:
                setattr(self, foreign_key, getattr(self, foreign_key))
                field_dict[foreign_key] = self.__data__[foreign_key]

    def save(self, force_insert=False, only=None):
        """
        Persist this instance with an INSERT or an UPDATE.

        An UPDATE is issued when the primary key is set and ``force_insert``
        is false; otherwise an INSERT is issued. ``only`` limits the saved
        fields. With ``only_save_dirty`` and no ``only``, just dirty fields
        are written. If none are dirty, nothing is executed and False is
        returned. After an INSERT on a model with a primary key, a returned
        key value is stored on the instance when the key is auto-increment
        or was unset.

        Returns the UPDATE's row count, or 1 after an INSERT. Raises
        ValueError when an UPDATE would have no columns to set. All dirty
        flags are cleared after a successful save.
        """
        field_dict = self.__data__.copy()
        if self._meta.primary_key is not False:
            pk_field = self._meta.primary_key
            pk_value = self._pk
        else:
            pk_field = pk_value = None
        if only is not None:
            field_dict = self._prune_fields(field_dict, only)
        elif self._meta.only_save_dirty and not force_insert:
            field_dict = self._prune_fields(field_dict, self.dirty_fields)
            if not field_dict:
                self._dirty.clear()
                return False

        self._populate_unsaved_relations(field_dict)
        rows = 1

        # Let the database assign an auto-increment key.
        if self._meta.auto_increment and pk_value is None:
            field_dict.pop(pk_field.name, None)

        if pk_value is not None and not force_insert:
            # UPDATE path: primary-key columns are used in WHERE, not SET.
            if self._meta.composite_key:
                for pk_part_name in pk_field.field_names:
                    field_dict.pop(pk_part_name, None)
            else:
                field_dict.pop(pk_field.name, None)
            if not field_dict:
                raise ValueError('no data to save!')
            rows = self.update(**field_dict).where(self._pk_expr()).execute()
        elif pk_field is not None:
            pk = self.insert(**field_dict).execute()
            if pk is not None and (self._meta.auto_increment or
                                   pk_value is None):
                self._pk = pk
        else:
            self.insert(**field_dict).execute()

        self._dirty.clear()
        return rows

    def is_dirty(self):
        """True if any field has been modified since the last save."""
        return bool(self._dirty)

    @property
    def dirty_fields(self):
        """Dirty Field objects, in the model's sorted field order."""
        return [f for f in self._meta.sorted_fields if f.name in self._dirty]

    def dependencies(self, search_nullable=False):
        """
        Yield ``(expression, foreign_key)`` pairs for rows that reference this
        instance, directly or transitively.

        The walk follows back-references starting from this model. Each
        related model is visited at most once. Nullable foreign-keys are
        yielded but only traversed further when ``search_nullable`` is true.
        """
        model_class = type(self)
        stack = [(type(self), None)]
        seen = set()

        while stack:
            klass, query = stack.pop()
            if klass in seen:
                continue
            seen.add(klass)
            for fk, rel_model in klass._meta.backrefs.items():
                # Direct reference: compare against this instance's value;
                # deeper levels: match against the parent level's subquery.
                if rel_model is model_class or query is None:
                    node = (fk == self.__data__[fk.rel_field.name])
                else:
                    node = fk << query
                subquery = (rel_model.select(rel_model._meta.primary_key)
                            .where(node))
                if not fk.null or search_nullable:
                    stack.append((rel_model, subquery))
                yield (node, fk)

    def delete_instance(self, recursive=False, delete_nullable=False):
        """
        DELETE this instance's row and return execute()'s result.

        With ``recursive``, dependent rows are handled first, deepest
        dependencies first. Rows referencing through a nullable foreign-key
        have it set to NULL unless ``delete_nullable`` is true; all other
        dependents are deleted.
        """
        if recursive:
            dependencies = self.dependencies(delete_nullable)
            for query, fk in reversed(list(dependencies)):
                model = fk.model
                if fk.null and not delete_nullable:
                    model.update(**{fk.name: None}).where(query).execute()
                else:
                    model.delete().where(query).execute()
        return type(self).delete().where(self._pk_expr()).execute()

    def __hash__(self):
        return hash((self.__class__, self._pk))

    def __eq__(self, other):
        # Same class and equal, non-None primary keys. An instance whose pk
        # is None is therefore not equal even to itself.
        return (
            other.__class__ == self.__class__ and
            self._pk is not None and
            self._pk == other._pk)

    def __ne__(self, other):
        return not self == other

    def __sql__(self, ctx):
        # NOTE: when comparing a foreign-key field whose related-field is not a
        # primary-key, then doing an equality test for the foreign-key with a
        # model instance will return the wrong value; since we would return
        # the primary key for a given model instance.
        #
        # This checks to see if we have a converter in the scope, and that we
        # are converting a foreign-key expression. If so, we hand the model
        # instance to the converter rather than blindly grabbing the primary-
        # key. In the event the provided converter fails to handle the model
        # instance, then we will return the primary-key.
        if ctx.state.converter is not None and ctx.state.is_fk_expr:
            try:
                return ctx.sql(Value(self, converter=ctx.state.converter))
            except (TypeError, ValueError):
                pass

        return ctx.sql(Value(getattr(self, self._meta.primary_key.name),
                             converter=self._meta.primary_key.db_value))

    @classmethod
    def bind(cls, database, bind_refs=True, bind_backrefs=True, _exclude=None):
        """
        Bind this model to ``database``.

        When requested, the models yielded by ``_meta.model_graph()`` for
        refs and/or backrefs are bound too. Models in ``_exclude`` are
        skipped, and every model bound here is added to that set. Returns
        True if the database differed from the previous one.
        """
        is_different = cls._meta.database is not database
        cls._meta.set_database(database)
        if bind_refs or bind_backrefs:
            if _exclude is None:
                _exclude = set()
            G = cls._meta.model_graph(refs=bind_refs, backrefs=bind_backrefs)
            for _, model, is_backref in G:
                if model not in _exclude:
                    model._meta.set_database(database)
                    _exclude.add(model)
        return is_different

    @classmethod
    def bind_ctx(cls, database, bind_refs=True, bind_backrefs=True):
        """Context manager that binds to ``database`` and restores it on exit."""
        return _BoundModelsContext((cls,), database, bind_refs, bind_backrefs)

    @classmethod
    def table_exists(cls):
        """Ask the schema manager's database whether this model's table exists."""
        M = cls._meta
        return cls._schema.database.table_exists(M.table.__name__, M.schema)

    @classmethod
    def create_table(cls, safe=True, **options):
        """
        Create the table (and related schema objects) via the schema manager.

        ``fail_silently`` is a deprecated alias for ``safe``. With ``safe``,
        an existing table is detected up front when the database lacks
        ``safe_create_index``.
        """
        if 'fail_silently' in options:
            __deprecated__('"fail_silently" has been deprecated in favor of '
                           '"safe" for the create_table() method.')
            safe = options.pop('fail_silently')

        if safe and not cls._schema.database.safe_create_index \
           and cls.table_exists():
            return
        if cls._meta.temporary:
            options.setdefault('temporary', cls._meta.temporary)
        cls._schema.create_all(safe, **options)

    @classmethod
    def drop_table(cls, safe=True, drop_sequences=True, **options):
        """
        Drop the table via the schema manager.

        With ``safe``, a missing table is detected up front when the
        database lacks ``safe_drop_index``.
        """
        if safe and not cls._schema.database.safe_drop_index \
           and not cls.table_exists():
            return
        if cls._meta.temporary:
            options.setdefault('temporary', cls._meta.temporary)
        cls._schema.drop_all(safe, drop_sequences, **options)

    @classmethod
    def truncate_table(cls, **options):
        """Delegate to the schema manager's truncate_table()."""
        cls._schema.truncate_table(**options)

    @classmethod
    def index(cls, *fields, **kwargs):
        """Return a ModelIndex over ``fields`` (not registered on the model)."""
        return ModelIndex(cls, fields, **kwargs)

    @classmethod
    def add_index(cls, *fields, **kwargs):
        """
        Register an index on ``_meta.indexes``.

        A lone SQL or Index argument is stored as-is; otherwise a ModelIndex
        is built from ``fields`` and ``kwargs``.
        """
        if len(fields) == 1 and isinstance(fields[0], (SQL, Index)):
            cls._meta.indexes.append(fields[0])
        else:
            cls._meta.indexes.append(ModelIndex(cls, fields, **kwargs))
6697
6698
class ModelAlias(Node):
    """Provide a separate reference to a model in a query."""
    def __init__(self, model, alias=None):
        # __setattr__ is disabled below, so write straight into __dict__.
        self.__dict__['model'] = model
        self.__dict__['alias'] = alias

    def __getattr__(self, attr):
        # Descriptors defined on the aliased model (ModelDescriptor) do not
        # work through a plain getattr() on the model, so invoke them with
        # the alias as the owner instead.
        descriptor = self.model.__dict__.get(attr)
        if isinstance(descriptor, ModelDescriptor):
            return descriptor.__get__(None, self)

        value = getattr(self.model, attr)
        if not isinstance(value, Field):
            return value
        # Build the field alias once and cache it on the instance.
        aliased = self.__dict__[attr] = FieldAlias.create(self, value)
        return aliased

    def __setattr__(self, attr, value):
        raise AttributeError('Cannot set attributes on model aliases.')

    def get_field_aliases(self):
        """Field aliases for every field of the model, in sorted order."""
        names = self.model._meta.sorted_field_names
        return list(map(lambda field_name: getattr(self, field_name), names))

    def select(self, *selection):
        """SELECT from this alias; defaults to all of its field aliases."""
        return ModelSelect(self, selection or self.get_field_aliases())

    def __call__(self, **kwargs):
        return self.model(**kwargs)

    def __sql__(self, ctx):
        if ctx.scope == SCOPE_VALUES:
            # Return the quoted table name.
            return ctx.sql(self.model)

        if self.alias:
            ctx.alias_manager[self] = self.alias

        if ctx.scope != SCOPE_SOURCE:
            # Refer to the table using the alias.
            return ctx.sql(Entity(ctx.alias_manager[self]))

        # Define the table and its alias.
        return (ctx
                .sql(self.model._meta.entity)
                .literal(' AS ')
                .sql(Entity(ctx.alias_manager[self])))
6755
6756
class FieldAlias(Field):
    """
    A field as seen through a ModelAlias.

    Value conversion (``adapt``, ``python_value``, ``db_value``) and other
    attribute lookups are delegated to the wrapped ``field``. SQL generation
    refers to the column through the alias ``source``. Instances are
    normally built with ``create()``, which also mixes in the wrapped
    field's own class, so isinstance() checks against that field type keep
    working.
    """
    def __init__(self, source, field):
        self.source = source
        self.model = source.model
        self.field = field

    @classmethod
    def create(cls, source, field):
        """Return an alias of ``field`` for ``source`` that is also a ``type(field)``."""
        class _FieldAlias(cls, type(field)):
            pass
        return _FieldAlias(source, field)

    def clone(self):
        """
        Return a new alias bound to the same source and field.

        ``type(self)`` is used so a clone of an alias made by ``create()``
        keeps the mixed-in field class and passes the same isinstance()
        checks. Previously it degraded to a plain ``FieldAlias``.
        """
        return type(self)(self.source, self.field)

    def adapt(self, value): return self.field.adapt(value)
    def python_value(self, value): return self.field.python_value(value)
    def db_value(self, value): return self.field.db_value(value)
    def __getattr__(self, attr):
        # Only reached for attributes missing from the instance, so the
        # 'model' branch is normally shadowed by the attribute set in
        # __init__.
        return self.source if attr == 'model' else getattr(self.field, attr)

    def __sql__(self, ctx):
        # Qualify the column with the alias rather than the real table.
        return ctx.sql(Column(self.source, self.field.column_name))
6780
6781
def sort_models(models):
    """
    Return ``models`` ordered so each model follows the models it needs.

    This is a depth-first post-order walk over non-deferred foreign-key
    targets and ``Meta.depends_on`` entries. Only models in the input are
    emitted. Roots are visited in ``(name, table_name)`` order, which makes
    the result deterministic.
    """
    pending = set(models)
    visited = set()
    result = []

    def visit(model):
        if model not in pending or model in visited:
            return
        visited.add(model)
        for fk, target in model._meta.refs.items():
            # Do not depth-first search deferred foreign-keys as this can
            # cause tables to be created in the incorrect order.
            if not fk.deferred:
                visit(target)
        for dependency in model._meta.depends_on or ():
            visit(dependency)
        result.append(model)

    ordered_roots = sorted(
        pending, key=lambda m: (m._meta.name, m._meta.table_name))
    for root in ordered_roots:
        visit(root)
    return result
6803
6804
class _ModelQueryHelper(object):
    """
    Mixin for queries that produce rows for a model.

    It falls back to the model's database when the query has none, and it
    picks the cursor wrapper that turns raw rows into the requested row
    type (model instances by default).
    """
    default_row_type = ROW.MODEL

    def __init__(self, *args, **kwargs):
        super(_ModelQueryHelper, self).__init__(*args, **kwargs)
        # Use the database the model is bound to unless one was given.
        if not self._database:
            self._database = self.model._meta.database

    @Node.copy
    def objects(self, constructor=None):
        """Use the CONSTRUCTOR row type, building rows with ``constructor``
        (default: the model class)."""
        self._row_type = ROW.CONSTRUCTOR
        self._constructor = constructor if constructor is not None else self.model

    def _get_cursor_wrapper(self, cursor):
        row_type = self._row_type or self.default_row_type
        if row_type == ROW.MODEL:
            return self._get_model_cursor_wrapper(cursor)

        if row_type == ROW.DICT:
            wrapper_class = ModelDictCursorWrapper
        elif row_type == ROW.TUPLE:
            wrapper_class = ModelTupleCursorWrapper
        elif row_type == ROW.NAMED_TUPLE:
            wrapper_class = ModelNamedTupleCursorWrapper
        elif row_type == ROW.CONSTRUCTOR:
            return ModelObjectCursorWrapper(cursor, self.model,
                                            self._returning, self._constructor)
        else:
            raise ValueError('Unrecognized row type: "%s".' % row_type)
        return wrapper_class(cursor, self.model, self._returning)

    def _get_model_cursor_wrapper(self, cursor):
        # Overridable hook; the base version builds plain model instances.
        return ModelObjectCursorWrapper(cursor, self.model, [], self.model)
6837
6838
class ModelRaw(_ModelQueryHelper, RawQuery):
    """A raw SQL query whose rows are returned as instances of ``model``."""
    def __init__(self, model, sql, params, **kwargs):
        # ``model`` must be set before the base initialisers run, because
        # _ModelQueryHelper.__init__ reads it.
        self.model = model
        self._returning = ()
        super(ModelRaw, self).__init__(sql=sql, params=params, **kwargs)

    def get(self):
        """Return the first result row, or raise the model's DoesNotExist."""
        try:
            return self.execute()[0]
        except IndexError:
            sql, params = self.sql()
            message = ('%s instance matching query does '
                       'not exist:\nSQL: %s\nParams: %s' %
                       (self.model, sql, params))
            raise self.model.DoesNotExist(message)
6853
6854
class BaseModelSelect(_ModelQueryHelper):
    """Behaviour shared by model SELECT queries and their compound forms.

    Set operations return a ``ModelCompoundSelectQuery`` bound to this
    query's model and are also available as operators: ``+`` (UNION ALL),
    ``|`` (UNION), ``&`` (INTERSECT) and ``-`` (EXCEPT).
    """

    def union_all(self, rhs):
        return ModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs)

    def union(self, rhs):
        return ModelCompoundSelectQuery(self.model, self, 'UNION', rhs)

    def intersect(self, rhs):
        return ModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs)

    def except_(self, rhs):
        return ModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs)

    # Operator spellings of the set operations above.
    __add__ = union_all
    __or__ = union
    __and__ = intersect
    __sub__ = except_

    def __iter__(self):
        """Iterate over result rows, executing the query first if needed."""
        needs_execution = not self._cursor_wrapper
        if needs_execution:
            self.execute()
        return iter(self._cursor_wrapper)

    def prefetch(self, *subqueries):
        """Run this query together with ``subqueries`` via ``prefetch()``."""
        return prefetch(self, *subqueries)

    def get(self, database=None):
        """Return the first result of a one-row page of this query.

        Raises ``self.model.DoesNotExist`` when no row matches.
        """
        single = self.paginate(1, 1)
        # Force a fresh execution rather than reusing any cached results.
        single._cursor_wrapper = None
        try:
            return single.execute(database)[0]
        except IndexError:
            sql, params = single.sql()
            raise self.model.DoesNotExist(
                '%s instance matching query does not exist:\n'
                'SQL: %s\nParams: %s' % (single.model, sql, params))

    @Node.copy
    def group_by(self, *columns):
        """Set GROUP BY, expanding models and column-declaring tables.

        Raises ValueError for a ``Table`` without explicitly declared
        columns.
        """
        grouping = []
        for column in columns:
            if is_model(column):
                grouping.extend(column._meta.sorted_fields)
                continue
            if not isinstance(column, Table):
                grouping.append(column)
                continue
            if not column._columns:
                raise ValueError('Cannot pass a table to group_by() that '
                                 'does not have columns explicitly '
                                 'declared.')
            grouping.extend([getattr(column, name)
                             for name in column._columns])
        self._group_by = grouping
6907
6908
class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery):
    """Compound (UNION / INTERSECT / EXCEPT) query over model selects.

    Model-row construction is delegated to the left-hand query.
    """
    def __init__(self, model, *args, **kwargs):
        self.model = model
        super(ModelCompoundSelectQuery, self).__init__(*args, **kwargs)

    def _get_model_cursor_wrapper(self, cursor):
        left_query = self.lhs
        return left_query._get_model_cursor_wrapper(cursor)
6916
6917
def _normalize_model_select(fields_or_models):
    """Flatten a selection into a list of column-like nodes.

    Model classes expand to their sorted fields, model aliases to their
    field aliases, and tables with declared columns to those columns; any
    other item is kept as-is.
    """
    normalized = []
    for item in fields_or_models:
        if is_model(item):
            expanded = item._meta.sorted_fields
        elif isinstance(item, ModelAlias):
            expanded = item.get_field_aliases()
        elif isinstance(item, Table) and item._columns:
            expanded = [getattr(item, name) for name in item._columns]
        else:
            expanded = [item]
        normalized.extend(expanded)
    return normalized
6930
6931
class ModelSelect(BaseModelSelect, Select):
    """SELECT query rooted at a model class.

    Beyond a plain ``Select`` this keeps a registry of model-aware joins
    (``_joins``) recording, for each join source, the joined destination,
    the attribute its row should be attached to and how to build that row;
    ``ModelCursorWrapper`` uses it to rebuild the object graph. Join
    predicates can be inferred from foreign keys, and ``filter()`` accepts
    Django-style keyword lookups.
    """
    def __init__(self, model, fields_or_models, is_default=False):
        # The join context is the source join() uses when no explicit
        # ``src`` is given; it starts at the root model.
        self.model = self._join_ctx = model
        # join source -> [(dest, attr, constructor, join_type), ...]
        self._joins = {}
        # Cleared by select() when columns are chosen explicitly; consulted
        # by __sql_selection__() when rendered as a subquery.
        self._is_default = is_default
        fields = _normalize_model_select(fields_or_models)
        super(ModelSelect, self).__init__([model], fields)

    def clone(self):
        """Return a copy of the query with its own join registry.

        Both the ``_joins`` dict and each per-source list are copied.
        join() appends to the per-source list in place, so copying only the
        dict would leave the lists shared: a join added to a query derived
        through clone() would then leak back into the original query (and
        be seen by its ensure_join()/filter() lookups and cursor wrapper).
        """
        clone = super(ModelSelect, self).clone()
        if clone._joins:
            clone._joins = dict((src, list(join_list))
                                for src, join_list in clone._joins.items())
        return clone

    def select(self, *fields_or_models):
        """Replace the selected columns.

        Models, model aliases and column-declaring tables are expanded to
        their fields. Calling with no arguments while the default selection
        is still in effect returns this query unchanged.
        """
        if fields_or_models or not self._is_default:
            self._is_default = False
            fields = _normalize_model_select(fields_or_models)
            return super(ModelSelect, self).select(*fields)
        return self

    def switch(self, ctx=None):
        """Set the source used by the next join() (default: root model).

        This modifies the query in place and returns it.
        """
        self._join_ctx = self.model if ctx is None else ctx
        return self

    def _get_model(self, src):
        """Return ``(model, is_model_class)`` for a join endpoint.

        ``is_model_class`` is True only when ``src`` is itself a model
        class; for a model-backed Table, a ModelAlias or a ModelSelect the
        underlying model is returned with False. ``(None, False)`` when no
        model can be determined.
        """
        if is_model(src):
            return src, True
        elif isinstance(src, Table) and src._model:
            return src._model, False
        elif isinstance(src, ModelAlias):
            return src.model, False
        elif isinstance(src, ModelSelect):
            return src.model, False
        return None, False

    def _normalize_join(self, src, dest, on, attr):
        """Resolve the ON clause, attribute name and row constructor.

        Returns ``(on, attr, constructor)``. When both endpoints resolve to
        models, a missing or Field/Column ``on`` becomes a predicate built
        from the foreign key, ``attr`` defaults to the FK name (or to the
        destination model's name for back-references or when no FK was
        identified), and the join context moves to ``dest``. Other
        ``Source`` destinations produce dict rows.

        NOTE(review): if neither case applies, ``constructor`` is never
        assigned and the final return raises UnboundLocalError.
        """
        # Allow "on" expression to have an alias that determines the
        # destination attribute for the joined data.
        on_alias = isinstance(on, Alias)
        if on_alias:
            attr = attr or on._alias
            on = on.alias()

        # Obtain references to the source and destination models being joined.
        src_model, src_is_model = self._get_model(src)
        dest_model, dest_is_model = self._get_model(dest)

        if src_model and dest_model:
            self._join_ctx = dest
            constructor = dest_model

            # In the case where the "on" clause is a Column or Field, we will
            # convert that field into the appropriate predicate expression.
            if not (src_is_model and dest_is_model) and isinstance(on, Column):
                if on.source is src:
                    to_field = src_model._meta.columns[on.name]
                elif on.source is dest:
                    to_field = dest_model._meta.columns[on.name]
                else:
                    raise AttributeError('"on" clause Column %s does not '
                                         'belong to %s or %s.' %
                                         (on, src_model, dest_model))
                on = None
            elif isinstance(on, Field):
                to_field = on
                on = None
            else:
                to_field = None

            fk_field, is_backref = self._generate_on_clause(
                src_model, dest_model, to_field, on)

            if on is None:
                # Model classes are addressed by field name, other sources
                # (tables, aliases) by column name.
                src_attr = 'name' if src_is_model else 'column_name'
                dest_attr = 'name' if dest_is_model else 'column_name'
                if is_backref:
                    lhs = getattr(dest, getattr(fk_field, dest_attr))
                    rhs = getattr(src, getattr(fk_field.rel_field, src_attr))
                else:
                    lhs = getattr(src, getattr(fk_field, src_attr))
                    rhs = getattr(dest, getattr(fk_field.rel_field, dest_attr))
                on = (lhs == rhs)

            if not attr:
                if fk_field is not None and not is_backref:
                    attr = fk_field.name
                else:
                    attr = dest_model._meta.name
            elif on_alias and fk_field is not None and \
                    attr == fk_field.object_id_name and not is_backref:
                raise ValueError('Cannot assign join alias to "%s", as this '
                                 'attribute is the object_id_name for the '
                                 'foreign-key field "%s"' % (attr, fk_field))

        elif isinstance(dest, Source):
            constructor = dict
            attr = attr or dest._alias
            if not attr and isinstance(dest, Table):
                attr = attr or dest.__name__

        return (on, attr, constructor)

    def _generate_on_clause(self, src, dest, to_field=None, on=None):
        """Choose the foreign key that joins model ``src`` to model ``dest``.

        Returns ``(fk_field, is_backref)``. ``fk_field`` is None when no
        foreign key links the models but an explicit ``on`` was given, or
        when several candidates exist and ``on`` does not single one out;
        in the multiple-candidate expression case ``is_backref`` is always
        reported as False.

        Raises ValueError when no foreign key exists and ``on`` is None, or
        when several exist, ``on`` is None and none is named after ``dest``.
        """
        meta = src._meta
        is_backref = fk_fields = False

        # Get all the foreign keys between source and dest, and determine if
        # the join is via a back-reference.
        if dest in meta.model_refs:
            fk_fields = meta.model_refs[dest]
        elif dest in meta.model_backrefs:
            fk_fields = meta.model_backrefs[dest]
            is_backref = True

        if not fk_fields:
            if on is not None:
                return None, False
            raise ValueError('Unable to find foreign key between %s and %s. '
                             'Please specify an explicit join condition.' %
                             (src, dest))
        elif to_field is not None:
            # If the foreign-key field was specified explicitly, remove all
            # other foreign-key fields from the list.
            target = (to_field.field if isinstance(to_field, FieldAlias)
                      else to_field)
            fk_fields = [f for f in fk_fields if (
                         (f is target) or
                         (is_backref and f.rel_field is to_field))]

        if len(fk_fields) == 1:
            return fk_fields[0], is_backref

        if on is None:
            # If multiple foreign-keys exist, try using the FK whose name
            # matches that of the related model. If not, raise an error as this
            # is ambiguous.
            for fk in fk_fields:
                if fk.name == dest._meta.name:
                    return fk, is_backref

            raise ValueError('More than one foreign key between %s and %s.'
                             ' Please specify which you are joining on.' %
                             (src, dest))

        # If there are multiple foreign-keys to choose from and the join
        # predicate is an expression, we'll try to figure out which
        # foreign-key field we're joining on so that we can assign to the
        # correct attribute when resolving the model graph.
        to_field = None
        if isinstance(on, Expression):
            lhs, rhs = on.lhs, on.rhs
            # Coerce to set() so that we force Python to compare using the
            # object's hash rather than equality test, which returns a
            # false-positive due to overriding __eq__.
            fk_set = set(fk_fields)

            if isinstance(lhs, Field):
                lhs_f = lhs.field if isinstance(lhs, FieldAlias) else lhs
                if lhs_f in fk_set:
                    to_field = lhs_f
            elif isinstance(rhs, Field):
                rhs_f = rhs.field if isinstance(rhs, FieldAlias) else rhs
                if rhs_f in fk_set:
                    to_field = rhs_f

        return to_field, False

    @Node.copy
    def join(self, dest, join_type=JOIN.INNER, on=None, src=None, attr=None):
        """Join ``dest`` from ``src`` (default: the current join context).

        LATERAL joins get a literal ``True`` predicate; CROSS joins may not
        take an ``on`` clause. Neither is recorded in ``_joins``. For other
        join types the predicate and attribute are normalized and, when an
        attribute name results, the join is recorded for object-graph
        reconstruction. The last FROM item is replaced by a ``Join``
        wrapping it.

        Raises ValueError when ``on`` is given for a CROSS join or when
        there is no FROM item to join onto.
        """
        src = self._join_ctx if src is None else src

        if join_type == JOIN.LATERAL or join_type == JOIN.LEFT_LATERAL:
            on = True
        elif join_type != JOIN.CROSS:
            on, attr, constructor = self._normalize_join(src, dest, on, attr)
            if attr:
                self._joins.setdefault(src, [])
                self._joins[src].append((dest, attr, constructor, join_type))
        elif on is not None:
            raise ValueError('Cannot specify on clause with cross join.')

        if not self._from_list:
            raise ValueError('No sources to join on.')

        item = self._from_list.pop()
        self._from_list.append(Join(item, dest, join_type, on))

    def join_from(self, src, dest, join_type=JOIN.INNER, on=None, attr=None):
        """Like join(), but with the source given explicitly first."""
        return self.join(dest, join_type, on, src, attr)

    def _get_model_cursor_wrapper(self, cursor):
        """Plain model rows for a single, join-free source; otherwise a
        ModelCursorWrapper that rebuilds the joined object graph."""
        if len(self._from_list) == 1 and not self._joins:
            return ModelObjectCursorWrapper(cursor, self.model,
                                            self._returning, self.model)
        return ModelCursorWrapper(cursor, self.model, self._returning,
                                  self._from_list, self._joins)

    def ensure_join(self, lm, rm, on=None, **join_kwargs):
        """Join ``lm`` to ``rm`` unless that join is already recorded.

        Returns ``self`` when ``rm`` is already joined from ``lm``;
        otherwise the joined query, with its join context restored to the
        one in effect before the call. Note that ``switch(lm)`` modifies
        this query's own join context before join() runs.
        """
        join_ctx = self._join_ctx
        for dest, _, _, _ in self._joins.get(lm, []):
            if dest == rm:
                return self
        return self.switch(lm).join(rm, on=on, **join_kwargs).switch(join_ctx)

    def convert_dict_to_node(self, qdict):
        """Translate a dict of Django-style lookups into expressions.

        Keys (processed in sorted order) are ``attr`` or
        ``rel__...__attr``, optionally ending in ``__<op>`` where ``op`` is a
        DJANGO_MAP key; without an op, a None value uses the ``'is'``
        operator and anything else ``'eq'``. Returns ``(expressions,
        joins)`` where ``joins`` lists the ForeignKeyField/BackrefAccessor
        objects traversed, which the caller must join.

        NOTE(review): if the last path segment matches an already-joined
        attribute, ``model_attr`` keeps its value from an earlier segment
        (or is unbound) -- confirm such keys are not expected.
        """
        accum = []
        joins = []
        fks = (ForeignKeyField, BackrefAccessor)
        for key, value in sorted(qdict.items()):
            curr = self.model
            if '__' in key and key.rsplit('__', 1)[1] in DJANGO_MAP:
                key, op = key.rsplit('__', 1)
                op = DJANGO_MAP[op]
            elif value is None:
                op = DJANGO_MAP['is']
            else:
                op = DJANGO_MAP['eq']

            if '__' not in key:
                # Handle simplest case. This avoids joining over-eagerly when a
                # direct FK lookup is all that is required.
                model_attr = getattr(curr, key)
            else:
                for piece in key.split('__'):
                    # Prefer a join that is already part of the query.
                    for dest, attr, _, _ in self._joins.get(curr, ()):
                        if attr == piece or (isinstance(dest, ModelAlias) and
                                             dest.alias == piece):
                            curr = dest
                            break
                    else:
                        model_attr = getattr(curr, piece)
                        if value is not None and isinstance(model_attr, fks):
                            curr = model_attr.rel_model
                            joins.append(model_attr)
            accum.append(op(model_attr, value))
        return accum, joins

    def filter(self, *args, **kwargs):
        """Return a filtered query from expressions and/or keyword lookups.

        ``args`` are expressions (possibly containing ``DQ`` nodes) and
        ``kwargs`` are Django-style lookups; both are AND-ed together. Any
        relations traversed by the lookups are joined (once each) before the
        WHERE clause is applied. With no arguments a clone is returned.
        """
        # normalize args and kwargs into a new expression
        if args and kwargs:
            dq_node = (reduce(operator.and_, [a.clone() for a in args]) &
                       DQ(**kwargs))
        elif args:
            dq_node = (reduce(operator.and_, [a.clone() for a in args]) &
                       ColumnBase())
        elif kwargs:
            dq_node = DQ(**kwargs) & ColumnBase()
        else:
            return self.clone()

        # dq_node should now be an Expression, lhs = Node(), rhs = ...
        q = collections.deque([dq_node])
        dq_joins = []
        seen_joins = set()
        while q:
            curr = q.popleft()
            if not isinstance(curr, Expression):
                continue
            for side, piece in (('lhs', curr.lhs), ('rhs', curr.rhs)):
                if isinstance(piece, DQ):
                    query, joins = self.convert_dict_to_node(piece.query)
                    for join in joins:
                        if join not in seen_joins:
                            dq_joins.append(join)
                            seen_joins.add(join)
                    expression = reduce(operator.and_, query)
                    # Apply values from the DQ object.
                    if piece._negated:
                        expression = Negated(expression)
                    #expression._alias = piece._alias
                    setattr(curr, side, expression)
                else:
                    q.append(piece)

        # Drop the placeholder ColumnBase() that was AND-ed on above.
        if not args or not kwargs:
            dq_node = dq_node.lhs

        query = self.clone()
        for field in dq_joins:
            if isinstance(field, ForeignKeyField):
                lm, rm = field.model, field.rel_model
                field_obj = field
            elif isinstance(field, BackrefAccessor):
                lm, rm = field.model, field.rel_model
                field_obj = field.field
            query = query.ensure_join(lm, rm, field_obj)
        return query.where(dq_node)

    def create_table(self, name, safe=True, **meta):
        """Create table ``name`` from this query via the model's schema
        manager (``create_table_as``)."""
        return self.model._schema.create_table_as(name, self, safe, **meta)

    def __sql_selection__(self, ctx, is_subquery=False):
        """Render the selection list.

        When the default selection is used in a subquery with several
        columns, only the primary key is emitted (if the model has one).
        """
        if self._is_default and is_subquery and len(self._returning) > 1 and \
           self.model._meta.primary_key is not False:
            return ctx.sql(self.model._meta.primary_key)

        return ctx.sql(CommaNodeList(self._returning))
7230
7231
class NoopModelSelect(ModelSelect):
    """ModelSelect that renders the database's no-op SELECT.

    Results come back through a plain ``CursorWrapper``, not model rows.
    """
    def __sql__(self, ctx):
        database = self.model._meta.database
        return database.get_noop_select(ctx)

    def _get_cursor_wrapper(self, cursor):
        return CursorWrapper(cursor)
7238
7239
class _ModelWriteQueryHelper(_ModelQueryHelper):
    """Shared behaviour for model INSERT, UPDATE and DELETE queries."""

    def __init__(self, model, *args, **kwargs):
        self.model = model
        super(_ModelWriteQueryHelper, self).__init__(model, *args, **kwargs)

    def returning(self, *returning):
        """Set RETURNING, expanding model classes into their sorted fields."""
        expanded = []
        for item in returning:
            if not is_model(item):
                expanded.append(item)
                continue
            expanded.extend(item._meta.sorted_fields)
        return super(_ModelWriteQueryHelper, self).returning(*expanded)

    def _set_table_alias(self, ctx):
        """Register the model's table under its own name in ``ctx``."""
        model_table = self.model._meta.table
        ctx.alias_manager[model_table] = model_table.__name__
7257
7258
class ModelUpdate(_ModelWriteQueryHelper, Update):
    """UPDATE query bound to a model class."""
7261
7262
class ModelInsert(_ModelWriteQueryHelper, Insert):
    """INSERT query bound to a model class.

    Results default to tuples. When no RETURNING clause was given and the
    database supports one, the model's primary key(s) are returned.
    """
    default_row_type = ROW.TUPLE

    def __init__(self, *args, **kwargs):
        super(ModelInsert, self).__init__(*args, **kwargs)
        if self._returning is not None:
            return
        database = self.model._meta.database
        if database is not None and database.returning_clause:
            self._returning = self.model._meta.get_primary_keys()

    def returning(self, *returning):
        """Set RETURNING; model rows become the default result shape.

        Without an explicit RETURNING the insert yields primary-key tuples;
        with one, and no row type chosen yet, model instances are returned.
        """
        if self._row_type is None and returning:
            self._row_type = ROW.MODEL
        return super(ModelInsert, self).returning(*returning)

    def get_default_data(self):
        return self.model._meta.defaults

    def get_default_columns(self):
        meta = self.model._meta
        fields = meta.sorted_fields
        if meta.auto_increment:
            # Leave out the first sorted field, presumably the
            # auto-incrementing primary key.
            return fields[1:]
        return fields
7287
7288
class ModelDelete(_ModelWriteQueryHelper, Delete):
    """DELETE query bound to a model class."""
7291
7292
class ManyToManyQuery(ModelSelect):
    """Select of the rows related to one instance through a many-to-many.

    Besides iterating the related rows, ``add()``, ``remove()`` and
    ``clear()`` edit rows of the accessor's through model.
    """
    def __init__(self, instance, accessor, rel, *args, **kwargs):
        self._instance = instance
        self._accessor = accessor
        self._src_attr = accessor.src_fk.rel_field.name
        self._dest_attr = accessor.dest_fk.rel_field.name
        super(ManyToManyQuery, self).__init__(rel, (rel,), *args, **kwargs)

    def _id_list(self, model_or_id_list):
        # Only the first element is inspected: the sequence is assumed to be
        # either all model instances or all raw identifiers.
        if not isinstance(model_or_id_list[0], Model):
            return model_or_id_list
        dest_attr = self._dest_attr
        return [getattr(obj, dest_attr) for obj in model_or_id_list]

    def add(self, value, clear_existing=False):
        """Link ``value`` (instances, ids or a SelectQuery) to the instance.

        With ``clear_existing`` all current links are removed first.
        """
        if clear_existing:
            self.clear()

        accessor = self._accessor
        src_id = getattr(self._instance, self._src_attr)
        if isinstance(value, SelectQuery):
            # INSERT ... SELECT (src_id, related id) straight from the query.
            subquery = value.columns(Value(src_id),
                                     accessor.dest_fk.rel_field)
            accessor.through_model.insert_from(
                fields=[accessor.src_fk, accessor.dest_fk],
                query=subquery).execute()
            return

        value = ensure_tuple(value)
        if not value:
            return
        rows = [{accessor.src_fk.name: src_id, accessor.dest_fk.name: rel_id}
                for rel_id in self._id_list(value)]
        accessor.through_model.insert_many(rows).execute()

    def remove(self, value):
        """Unlink ``value`` (instances, ids or a SelectQuery)."""
        src_id = getattr(self._instance, self._src_attr)
        if isinstance(value, SelectQuery):
            column = getattr(value.model, self._dest_attr)
            dest_filter = self._accessor.dest_fk << value.columns(column)
        else:
            value = ensure_tuple(value)
            if not value:
                return
            dest_filter = self._accessor.dest_fk << self._id_list(value)
        return (self._accessor.through_model
                .delete()
                .where(dest_filter & (self._accessor.src_fk == src_id))
                .execute())

    def clear(self):
        """Remove every through-model row linked to the instance."""
        src_id = getattr(self._instance, self._src_attr)
        delete_query = (self._accessor.through_model
                        .delete()
                        .where(self._accessor.src_fk == src_id))
        return delete_query.execute()
7357
7358
def safe_python_value(conv_func):
    """Wrap ``conv_func`` so failed conversions fall back to the raw value.

    The returned callable applies ``conv_func``; when that raises TypeError
    or ValueError, the input is returned unchanged. Any other exception
    propagates.
    """
    def validate(value):
        try:
            converted = conv_func(value)
        except (TypeError, ValueError):
            converted = value
        return converted
    return validate
7366
7367
class BaseModelCursorWrapper(DictCursorWrapper):
    """Base class for cursor wrappers that convert rows using model metadata.

    ``initialize()`` inspects the cursor description once and records, per
    result column, an output name (``columns``), an optional value
    converter (``converters``) and the matching model Field, if any
    (``fields``). Subclasses implement ``process_row()``.
    """
    def __init__(self, cursor, model, columns):
        super(BaseModelCursorWrapper, self).__init__(cursor)
        self.model = model
        # Selected nodes, expected to line up positionally with the cursor
        # description; may be empty (e.g. raw queries pass ``()`` / ``[]``).
        self.select = columns or []

    def _initialize_columns(self):
        """Resolve the name, converter and field of every result column.

        Sets ``self.ncols``, ``self.columns``, ``self.converters`` and
        ``self.fields``; entries of the last two stay None when nothing
        could be determined for that column.
        """
        # NOTE(review): ``combined`` appears to map names to Field objects
        # (its values expose ``python_value``) -- confirm.
        combined = self.model._meta.combined
        table = self.model._meta.table
        description = self.cursor.description

        self.ncols = len(self.cursor.description)
        self.columns = []
        self.converters = converters = [None] * self.ncols
        self.fields = fields = [None] * self.ncols

        for idx, description_item in enumerate(description):
            column = description_item[0]
            # Drop everything up to the first "." and then strip surrounding
            # '"' and ')' characters, e.g. SUM("t1"."price") -> price.
            dot_index = column.find('.')
            if dot_index != -1:
                column = column[dot_index + 1:]

            column = column.strip('")')
            self.columns.append(column)
            try:
                raw_node = self.select[idx]
            except IndexError:
                # No selected node at this position: fall back to looking the
                # column name up among the model's fields, otherwise leave
                # the column unconverted.
                if column in combined:
                    raw_node = node = combined[column]
                else:
                    continue
            else:
                node = raw_node.unwrap()

            # Heuristics used to attempt to get the field associated with a
            # given SELECT column, so that we can accurately convert the value
            # returned by the database-cursor into a Python object.
            if isinstance(node, Field):
                if raw_node._coerce:
                    converters[idx] = node.python_value
                fields[idx] = node
                if not raw_node.is_alias():
                    # Un-aliased fields are keyed by the field's name rather
                    # than the name reported by the cursor.
                    self.columns[idx] = node.name
            elif isinstance(node, ColumnBase) and raw_node._converter:
                converters[idx] = raw_node._converter
            elif isinstance(node, Function) and node._coerce:
                if node._python_value is not None:
                    converters[idx] = node._python_value
                elif node.arguments and isinstance(node.arguments[0], Node):
                    # If the first argument is a field or references a column
                    # on a Model, try using that field's conversion function.
                    # This usually works, but we use "safe_python_value()" so
                    # that if a TypeError or ValueError occurs during
                    # conversion we can just fall-back to the raw cursor value.
                    first = node.arguments[0].unwrap()
                    if isinstance(first, Entity):
                        path = first._path[-1]  # Try to look-up by name.
                        first = combined.get(path)
                    if isinstance(first, Field):
                        converters[idx] = safe_python_value(first.python_value)
            elif column in combined:
                if node._coerce:
                    converters[idx] = combined[column].python_value
                if isinstance(node, Column) and node.source == table:
                    fields[idx] = combined[column]

    # Default initialize(); subclasses that need extra setup override
    # initialize() and call _initialize_columns() themselves.
    initialize = _initialize_columns

    def process_row(self, row):
        """Convert one raw cursor row; implemented by subclasses."""
        raise NotImplementedError
7435
7436    def process_row(self, row):
7437        raise NotImplementedError
7438
7439
class ModelDictCursorWrapper(BaseModelCursorWrapper):
    """Cursor wrapper that yields each row as a ``{column: value}`` dict."""

    def process_row(self, row):
        """Convert one raw cursor row into a dict keyed by column name.

        Each value goes through the column's converter when one was resolved
        during initialization. When several columns share a name, the first
        occurrence wins.
        """
        result = {}
        columns, converters = self.columns, self.converters

        for i in range(self.ncols):
            attr = columns[i]
            if attr in result:
                continue  # Don't overwrite if we have dupes.
            converter = converters[i]
            result[attr] = row[i] if converter is None else converter(row[i])

        return result
7455
7456
class ModelTupleCursorWrapper(ModelDictCursorWrapper):
    """Cursor wrapper that yields each row as ``constructor(values)``.

    ``constructor`` is ``tuple`` unless a subclass or ``initialize()``
    replaces it.
    """
    constructor = tuple

    def process_row(self, row):
        """Apply per-column converters and hand the values to constructor."""
        converters = self.converters
        return self.constructor([
            (converters[i](row[i]) if converters[i] is not None else row[i])
            for i in range(self.ncols)])
7465
7466
class ModelNamedTupleCursorWrapper(ModelTupleCursorWrapper):
    """Cursor wrapper that yields each row as a ``Row`` namedtuple."""

    def initialize(self):
        """Resolve the result columns, then build the ``Row`` class.

        Column names are not guaranteed to be unique, valid identifiers.
        For example, two un-aliased Field columns with the same ``name``
        both end up under that name. Without ``rename=True``, namedtuple
        raises ValueError for such names. With it, the offending entries
        are replaced by positional names (``_<index>``); names that are
        valid and unique are kept unchanged.
        """
        self._initialize_columns()
        attributes = [self.columns[i] for i in range(self.ncols)]
        self.tuple_class = collections.namedtuple('Row', attributes,
                                                  rename=True)
        self.constructor = lambda row: self.tuple_class(*row)
7475
7476
class ModelObjectCursorWrapper(ModelDictCursorWrapper):
    """Cursor wrapper that passes each row dict to ``constructor`` as kwargs.

    Model classes are instantiated with ``__no_default__=1`` and have their
    dirty-field set cleared before being returned.
    """
    def __init__(self, cursor, model, select, constructor):
        self.constructor = constructor
        self.is_model = is_model(constructor)
        super(ModelObjectCursorWrapper, self).__init__(cursor, model, select)

    def process_row(self, row):
        data = super(ModelObjectCursorWrapper, self).process_row(row)
        if not self.is_model:
            return self.constructor(**data)
        # Freshly loaded rows should not report any field as modified.
        instance = self.constructor(__no_default__=1, **data)
        instance._dirty.clear()
        return instance
7492
7493
class ModelCursorWrapper(BaseModelCursorWrapper):
    """Cursor wrapper that rebuilds a graph of joined objects for each row.

    Every row produces one object per known source (the root model, joined
    models/aliases, and whatever constructor was recorded for other joined
    sources). Columns are assigned to the object of the source they came
    from. Joined objects are attached to their source object under the
    attribute recorded in ``joins``, and the root ``model`` object is
    returned.
    """
    def __init__(self, cursor, model, select, from_list, joins):
        super(ModelCursorWrapper, self).__init__(cursor, model, select)
        # The query's FROM list (may contain nested Join nodes) and its join
        # registry: source -> [(dest, attr, constructor, join_type), ...].
        self.from_list = from_list
        self.joins = joins

    def initialize(self):
        """Work out which object each column and joined row belongs to.

        Builds ``key_to_constructor`` (source -> callable creating an empty
        object for it), ``src_to_dest`` (join edges as
        ``(src, attr, dest, is_dict, join_type)``), ``src_is_dest`` and
        ``column_keys`` (the source owning each selected column).
        """
        self._initialize_columns()
        # The ``model`` of every column that resolved to a Field.
        selected_src = set([field.model for field in self.fields
                            if field is not None])
        select, columns = self.select, self.columns

        self.key_to_constructor = {self.model: self.model}
        self.src_is_dest = {}
        self.src_to_dest = []
        accum = collections.deque(self.from_list)
        dests = set()

        # Breadth-first walk: Join nodes are expanded into both sides, and
        # each recorded join from a source is followed, visiting every
        # destination at most once.
        while accum:
            curr = accum.popleft()
            if isinstance(curr, Join):
                accum.append(curr.lhs)
                accum.append(curr.rhs)
                continue

            if curr not in self.joins:
                continue

            # NOTE(review): ``curr`` is a join source (model, table, alias,
            # ...), so this looks like it is never a dict -- confirm whether
            # the constructor of ``curr`` was meant to be checked instead.
            is_dict = isinstance(curr, dict)
            for key, attr, constructor, join_type in self.joins[curr]:
                if key not in self.key_to_constructor:
                    self.key_to_constructor[key] = constructor

                    # (src, attr, dest, is_dict, join_type).
                    self.src_to_dest.append((curr, attr, key, is_dict,
                                             join_type))
                    dests.add(key)
                    accum.append(key)

        # Ensure that we accommodate everything selected.
        for src in selected_src:
            if src not in self.key_to_constructor:
                if is_model(src):
                    self.key_to_constructor[src] = src
                elif isinstance(src, ModelAlias):
                    self.key_to_constructor[src] = src.model

        # Indicate which sources are also dests.
        # process_row() looks this up for a join's destination: a destination
        # that is itself the source of a further join is attached even when
        # none of its own columns were non-NULL (subject to the OUTER JOIN
        # check there).
        for src, _, dest, _, _ in self.src_to_dest:
            self.src_is_dest[src] = src in dests and (dest in selected_src
                                                      or src in selected_src)

        # Map each selected column to the source whose object receives it:
        # aliased fields go to their alias, other fields to their model,
        # bare Columns to their source, anything else to the root model.
        self.column_keys = []
        for idx, node in enumerate(select):
            key = self.model
            field = self.fields[idx]
            if field is not None:
                if isinstance(field, FieldAlias):
                    key = field.source
                else:
                    key = field.model
            else:
                if isinstance(node, Node):
                    node = node.unwrap()
                if isinstance(node, Column):
                    key = node.source

            self.column_keys.append(key)

    def process_row(self, row):
        """Build the object graph for one row and return the root object."""
        objects = {}
        object_list = []
        # A fresh, empty object for every known source.
        # NOTE(review): when the constructor is ``dict`` this call also puts
        # a '__no_default__' key into the row dict -- confirm intended.
        for key, constructor in self.key_to_constructor.items():
            objects[key] = constructor(__no_default__=True)
            object_list.append(objects[key])

        default_instance = objects[self.model]

        # Sources that received at least one non-NULL value in this row.
        set_keys = set()
        for idx, key in enumerate(self.column_keys):
            # Get the instance corresponding to the selected column/value,
            # falling back to the "root" model instance.
            instance = objects.get(key, default_instance)
            column = self.columns[idx]
            value = row[idx]
            if value is not None:
                set_keys.add(key)
            if self.converters[idx]:
                value = self.converters[idx](value)

            if isinstance(instance, dict):
                instance[column] = value
            else:
                setattr(instance, column, value)

        # Need to do some analysis on the joins before this.
        for (src, attr, dest, is_dict, join_type) in self.src_to_dest:
            instance = objects[src]
            try:
                joined_instance = objects[dest]
            except KeyError:
                continue

            # If no fields were set on the destination instance then do not
            # assign an "empty" instance.
            if instance is None or dest is None or \
               (dest not in set_keys and not self.src_is_dest.get(dest)):
                continue

            # If no fields were set on either the source or the destination,
            # then we have nothing to do here.
            # NOTE(review): ``instance`` is the built object while
            # ``set_keys`` holds source keys, so ``instance not in set_keys``
            # looks always true for model objects (and would raise TypeError
            # for an unhashable dict object); ``src`` was presumably meant.
            if instance not in set_keys and dest not in set_keys \
               and join_type.endswith('OUTER JOIN'):
                continue

            if is_dict:
                instance[attr] = joined_instance
            else:
                setattr(instance, attr, joined_instance)

        # When instantiating models from a cursor, we clear the dirty fields.
        for instance in object_list:
            if isinstance(instance, Model):
                instance._dirty.clear()

        return objects[self.model]
7620
7621
class PrefetchQuery(collections.namedtuple('_PrefetchQuery', (
    'query', 'fields', 'is_backref', 'rel_models', 'field_to_name', 'model'))):
    """One step of a prefetch: a query plus the fields that link its rows.

    Tuple fields:

    * ``query`` -- the query whose rows are loaded for this step.
    * ``fields`` -- relationship fields linking this step to another, or
      ``None`` for a step with no links (e.g. the root query).
    * ``is_backref`` -- truthy when ``fields`` are followed in reverse.
    * ``rel_models`` -- the models at the other end of ``fields``.
    * ``field_to_name`` -- ``(field, attribute name)`` pairs; the name selects
      the row value in ``store_instance`` and, for forward steps, the
      attribute assigned on related rows in ``populate_instance``.
    * ``model`` -- always ``query.model``; the ``model`` constructor argument
      is accepted but ignored.
    """

    def __new__(cls, query, fields=None, is_backref=None, rel_models=None,
                field_to_name=None, model=None):
        if fields:
            if rel_models is None:
                # Reverse steps relate to the field's own model, forward
                # steps to the model the field references.
                rel_models = [f.model if is_backref else f.rel_model
                              for f in fields]
            attr_names = [f.rel_field.name if is_backref else f.name
                          for f in fields]
            field_to_name = list(zip(fields, attr_names))
        return super(PrefetchQuery, cls).__new__(
            cls, query, fields, is_backref, rel_models, field_to_name,
            query.model)

    def populate_instance(self, instance, id_map):
        """Attach the related rows held in *id_map* to *instance*.

        Forward steps: every row stored under ``(field, referenced value)``
        gets *instance* assigned to its ``attname`` (and its ``_dirty`` set
        cleared); that list -- empty when nothing matched -- is then assigned
        to ``field.backref`` on *instance*.

        Reverse steps: the single row stored under ``(field, value)`` is
        assigned to ``field.name`` on *instance*, when one exists.
        """
        if not self.is_backref:
            for field, attname in self.field_to_name:
                lookup = (field, instance.__data__[field.rel_field.name])
                children = id_map.get(lookup, [])
                for child in children:
                    setattr(child, attname, instance)
                    child._dirty.clear()
                setattr(instance, field.backref, children)
            return

        for field in self.fields:
            lookup = (field, instance.__data__[field.name])
            if lookup in id_map:
                setattr(instance, field.name, id_map[lookup])

    def store_instance(self, instance, id_map):
        """Index *instance* in *id_map* under ``(field, value)`` per link.

        ``value`` is ``instance.__data__[attname]`` passed through
        ``field.rel_field.python_value``. Reverse steps keep one row per key;
        forward steps accumulate a list of rows per key.
        """
        for field, attname in self.field_to_name:
            raw = instance.__data__[attname]
            lookup = (field, field.rel_field.python_value(raw))
            if self.is_backref:
                id_map[lookup] = instance
            else:
                id_map.setdefault(lookup, []).append(instance)
7666
7667
def prefetch_add_subquery(sq, subqueries):
    """Build the list of ``PrefetchQuery`` steps used by ``prefetch``.

    The first step wraps *sq*. Each item of *subqueries* is a query, a model
    or ``ModelAlias`` (turned into a query with ``.select()``), or a
    ``(query, target_model)`` tuple. For each item, the steps added so far are
    searched from most to least recent for a foreign key on the subquery's
    model pointing at that step's model (``_meta.model_refs``) or, failing
    that, a back-reference (``_meta.model_backrefs``). When *target_model* is
    given, only the step whose model is *target_model* is accepted. The
    subquery is then filtered with ``IN (<earlier query>)`` expressions,
    OR-ed together, and appended as a new step.

    :raises AttributeError: if no acceptable relationship is found.
    """
    fixed_queries = [PrefetchQuery(sq)]
    for i, subquery in enumerate(subqueries):
        if isinstance(subquery, tuple):
            subquery, target_model = subquery
        else:
            target_model = None
        if not isinstance(subquery, Query) and is_model(subquery) or \
           isinstance(subquery, ModelAlias):
            subquery = subquery.select()
        subquery_model = subquery.model

        for j in reversed(range(i + 1)):
            # Reset for every candidate so a relationship found on a rejected
            # (non-target) step cannot leak into a later iteration, where it
            # would be paired with the wrong last_query / last_obj.
            fks = backrefs = pks = None
            fixed = fixed_queries[j]
            last_query = fixed.query
            last_model = last_obj = fixed.model
            if isinstance(last_model, ModelAlias):
                last_model = last_model.model
            rels = subquery_model._meta.model_refs.get(last_model, [])
            if rels:
                fks = [getattr(subquery_model, fk.name) for fk in rels]
                pks = [getattr(last_obj, fk.rel_field.name) for fk in rels]
            else:
                backrefs = subquery_model._meta.model_backrefs.get(last_model)
            if (fks or backrefs) and (target_model is None or
                                      target_model is last_obj):
                break
        else:
            # No earlier step matched (or none matched target_model).
            tgt_err = ' using %s' % target_model if target_model else ''
            raise AttributeError('Error: unable to find foreign key for '
                                 'query: %s%s' % (subquery, tgt_err))

        dest = (target_model,) if target_model else None

        if fks:
            # Forward FK: subquery.fk IN (SELECT pk FROM last_query).
            expr = reduce(operator.or_, [
                (fk << last_query.select(pk))
                for (fk, pk) in zip(fks, pks)])
            subquery = subquery.where(expr)
            fixed_queries.append(PrefetchQuery(subquery, fks, False, dest))
        else:
            # Back-reference (non-empty after the break): the subquery's
            # referenced column IN (SELECT fk FROM last_query).
            expr = reduce(operator.or_, [
                (getattr(subquery_model, backref.rel_field.name) <<
                 last_query.select(getattr(last_obj, backref.name)))
                for backref in backrefs])
            subquery = subquery.where(expr)
            fixed_queries.append(PrefetchQuery(subquery, backrefs, True, dest))

    return fixed_queries
7719
7720
def prefetch(sq, *subqueries):
    """Run *sq* together with *subqueries*, wiring related rows together.

    With no *subqueries*, *sq* is returned unchanged. Otherwise the steps
    built by ``prefetch_add_subquery`` are executed from the last one back to
    *sq*. Each row of a linked step is indexed (``store_instance``) into the
    id map of that step's model. Rows of a model that other steps point at
    are populated (``populate_instance``) from those steps' id maps.

    :returns: a list of the rows of *sq*, obtained by iterating the root
        query a second time (presumably served from its result cache --
        confirm).
    """
    if not subqueries:
        return sq

    steps = prefetch_add_subquery(sq, subqueries)
    id_maps = {}    # step model -> id map filled by store_instance
    pointing = {}   # model -> steps whose rel_models include that model

    for step in reversed(steps):
        step_model = step.model
        has_links = bool(step.fields)
        if has_links:
            for rel_model in step.rel_models:
                pointing.setdefault(rel_model, []).append(step)

        id_map = id_maps.setdefault(step_model, {})
        related_steps = pointing.get(step_model)

        for row in step.query:
            if has_links:
                step.store_instance(row, id_map)
            if related_steps:
                for rel in related_steps:
                    rel.populate_instance(row, id_maps[rel.model])

    # The reversed walk ends on steps[0], the step wrapping *sq*.
    return list(steps[0].query)
7747