1from bisect import bisect_left 2from bisect import bisect_right 3from contextlib import contextmanager 4from copy import deepcopy 5from functools import wraps 6from inspect import isclass 7import calendar 8import collections 9import datetime 10import decimal 11import hashlib 12import itertools 13import logging 14import operator 15import re 16import socket 17import struct 18import sys 19import threading 20import time 21import uuid 22import warnings 23try: 24 from collections.abc import Mapping 25except ImportError: 26 from collections import Mapping 27 28try: 29 from pysqlite3 import dbapi2 as pysq3 30except ImportError: 31 try: 32 from pysqlite2 import dbapi2 as pysq3 33 except ImportError: 34 pysq3 = None 35try: 36 import sqlite3 37except ImportError: 38 sqlite3 = pysq3 39else: 40 if pysq3 and pysq3.sqlite_version_info >= sqlite3.sqlite_version_info: 41 sqlite3 = pysq3 42try: 43 from psycopg2cffi import compat 44 compat.register() 45except ImportError: 46 pass 47try: 48 import psycopg2 49 from psycopg2 import extensions as pg_extensions 50 try: 51 from psycopg2 import errors as pg_errors 52 except ImportError: 53 pg_errors = None 54except ImportError: 55 psycopg2 = pg_errors = None 56try: 57 from psycopg2.extras import register_uuid as pg_register_uuid 58 pg_register_uuid() 59except Exception: 60 pass 61 62mysql_passwd = False 63try: 64 import pymysql as mysql 65except ImportError: 66 try: 67 import MySQLdb as mysql 68 mysql_passwd = True 69 except ImportError: 70 mysql = None 71 72 73__version__ = '3.14.4' 74__all__ = [ 75 'AsIs', 76 'AutoField', 77 'BareField', 78 'BigAutoField', 79 'BigBitField', 80 'BigIntegerField', 81 'BinaryUUIDField', 82 'BitField', 83 'BlobField', 84 'BooleanField', 85 'Case', 86 'Cast', 87 'CharField', 88 'Check', 89 'chunked', 90 'Column', 91 'CompositeKey', 92 'Context', 93 'Database', 94 'DatabaseError', 95 'DatabaseProxy', 96 'DataError', 97 'DateField', 98 'DateTimeField', 99 'DecimalField', 100 'DeferredForeignKey', 101 'DeferredThroughModel', 102 'DJANGO_MAP', 103 'DoesNotExist', 104 'DoubleField', 105 'DQ', 106 'EXCLUDED', 107 'Field', 108 'FixedCharField', 109 'FloatField', 110 'fn', 111 'ForeignKeyField', 112 'IdentityField', 113 'ImproperlyConfigured', 114 'Index', 115 'IntegerField', 116 'IntegrityError', 117 'InterfaceError', 118 'InternalError', 119 'IPField', 120 'JOIN', 121 'ManyToManyField', 122 'Model', 123 'ModelIndex', 124 'MySQLDatabase', 125 'NotSupportedError', 126 'OP', 127 'OperationalError', 128 'PostgresqlDatabase', 129 'PrimaryKeyField', # XXX: Deprecated, change to AutoField. 130 'prefetch', 131 'ProgrammingError', 132 'Proxy', 133 'QualifiedNames', 134 'SchemaManager', 135 'SmallIntegerField', 136 'Select', 137 'SQL', 138 'SqliteDatabase', 139 'Table', 140 'TextField', 141 'TimeField', 142 'TimestampField', 143 'Tuple', 144 'UUIDField', 145 'Value', 146 'ValuesList', 147 'Window', 148] 149 150try: # Python 2.7+ 151 from logging import NullHandler 152except ImportError: 153 class NullHandler(logging.Handler): 154 def emit(self, record): 155 pass 156 157logger = logging.getLogger('peewee') 158logger.addHandler(NullHandler()) 159 160 161if sys.version_info[0] == 2: 162 text_type = unicode 163 bytes_type = str 164 buffer_type = buffer 165 izip_longest = itertools.izip_longest 166 callable_ = callable 167 multi_types = (list, tuple, frozenset, set) 168 exec('def reraise(tp, value, tb=None): raise tp, value, tb') 169 def print_(s): 170 sys.stdout.write(s) 171 sys.stdout.write('\n') 172else: 173 import builtins 174 try: 175 from collections.abc import Callable 176 except ImportError: 177 from collections import Callable 178 from functools import reduce 179 callable_ = lambda c: isinstance(c, Callable) 180 text_type = str 181 bytes_type = bytes 182 buffer_type = memoryview 183 basestring = str 184 long = int 185 multi_types = (list, tuple, frozenset, set, range) 186 print_ = getattr(builtins, 'print') 187 izip_longest = itertools.zip_longest 188 def reraise(tp, value, tb=None): 189 if value.__traceback__ is not tb: 190 raise value.with_traceback(tb) 191 raise value 192 193 194if sqlite3: 195 sqlite3.register_adapter(decimal.Decimal, str) 196 sqlite3.register_adapter(datetime.date, str) 197 sqlite3.register_adapter(datetime.time, str) 198 __sqlite_version__ = sqlite3.sqlite_version_info 199else: 200 __sqlite_version__ = (0, 0, 0) 201 202 203__date_parts__ = set(('year', 'month', 'day', 'hour', 'minute', 'second')) 204 205# Sqlite does not support the `date_part` SQL function, so we will define an 206# implementation in python. 207__sqlite_datetime_formats__ = ( 208 '%Y-%m-%d %H:%M:%S', 209 '%Y-%m-%d %H:%M:%S.%f', 210 '%Y-%m-%d', 211 '%H:%M:%S', 212 '%H:%M:%S.%f', 213 '%H:%M') 214 215__sqlite_date_trunc__ = { 216 'year': '%Y-01-01 00:00:00', 217 'month': '%Y-%m-01 00:00:00', 218 'day': '%Y-%m-%d 00:00:00', 219 'hour': '%Y-%m-%d %H:00:00', 220 'minute': '%Y-%m-%d %H:%M:00', 221 'second': '%Y-%m-%d %H:%M:%S'} 222 223__mysql_date_trunc__ = __sqlite_date_trunc__.copy() 224__mysql_date_trunc__['minute'] = '%Y-%m-%d %H:%i:00' 225__mysql_date_trunc__['second'] = '%Y-%m-%d %H:%i:%S' 226 227def _sqlite_date_part(lookup_type, datetime_string): 228 assert lookup_type in __date_parts__ 229 if not datetime_string: 230 return 231 dt = format_date_time(datetime_string, __sqlite_datetime_formats__) 232 return getattr(dt, lookup_type) 233 234def _sqlite_date_trunc(lookup_type, datetime_string): 235 assert lookup_type in __sqlite_date_trunc__ 236 if not datetime_string: 237 return 238 dt = format_date_time(datetime_string, __sqlite_datetime_formats__) 239 return dt.strftime(__sqlite_date_trunc__[lookup_type]) 240 241 242def __deprecated__(s): 243 warnings.warn(s, DeprecationWarning) 244 245 246class attrdict(dict): 247 def __getattr__(self, attr): 248 try: 249 return self[attr] 250 except KeyError: 251 raise AttributeError(attr) 252 def __setattr__(self, attr, value): self[attr] = value 253 def __iadd__(self, rhs): self.update(rhs); return self 254 def __add__(self, rhs): d = attrdict(self); d.update(rhs); return d 255 256SENTINEL = object() 257 258#: Operations for use in SQL expressions. 259OP = attrdict( 260 AND='AND', 261 OR='OR', 262 ADD='+', 263 SUB='-', 264 MUL='*', 265 DIV='/', 266 BIN_AND='&', 267 BIN_OR='|', 268 XOR='#', 269 MOD='%', 270 EQ='=', 271 LT='<', 272 LTE='<=', 273 GT='>', 274 GTE='>=', 275 NE='!=', 276 IN='IN', 277 NOT_IN='NOT IN', 278 IS='IS', 279 IS_NOT='IS NOT', 280 LIKE='LIKE', 281 ILIKE='ILIKE', 282 BETWEEN='BETWEEN', 283 REGEXP='REGEXP', 284 IREGEXP='IREGEXP', 285 CONCAT='||', 286 BITWISE_NEGATION='~') 287 288# To support "django-style" double-underscore filters, create a mapping between 289# operation name and operation code, e.g. "__eq" == OP.EQ. 290DJANGO_MAP = attrdict({ 291 'eq': operator.eq, 292 'lt': operator.lt, 293 'lte': operator.le, 294 'gt': operator.gt, 295 'gte': operator.ge, 296 'ne': operator.ne, 297 'in': operator.lshift, 298 'is': lambda l, r: Expression(l, OP.IS, r), 299 'like': lambda l, r: Expression(l, OP.LIKE, r), 300 'ilike': lambda l, r: Expression(l, OP.ILIKE, r), 301 'regexp': lambda l, r: Expression(l, OP.REGEXP, r), 302}) 303 304#: Mapping of field type to the data-type supported by the database. Databases 305#: may override or add to this list. 306FIELD = attrdict( 307 AUTO='INTEGER', 308 BIGAUTO='BIGINT', 309 BIGINT='BIGINT', 310 BLOB='BLOB', 311 BOOL='SMALLINT', 312 CHAR='CHAR', 313 DATE='DATE', 314 DATETIME='DATETIME', 315 DECIMAL='DECIMAL', 316 DEFAULT='', 317 DOUBLE='REAL', 318 FLOAT='REAL', 319 INT='INTEGER', 320 SMALLINT='SMALLINT', 321 TEXT='TEXT', 322 TIME='TIME', 323 UUID='TEXT', 324 UUIDB='BLOB', 325 VARCHAR='VARCHAR') 326 327#: Join helpers (for convenience) -- all join types are supported, this object 328#: is just to help avoid introducing errors by using strings everywhere. 329JOIN = attrdict( 330 INNER='INNER JOIN', 331 LEFT_OUTER='LEFT OUTER JOIN', 332 RIGHT_OUTER='RIGHT OUTER JOIN', 333 FULL='FULL JOIN', 334 FULL_OUTER='FULL OUTER JOIN', 335 CROSS='CROSS JOIN', 336 NATURAL='NATURAL JOIN', 337 LATERAL='LATERAL', 338 LEFT_LATERAL='LEFT JOIN LATERAL') 339 340# Row representations. 341ROW = attrdict( 342 TUPLE=1, 343 DICT=2, 344 NAMED_TUPLE=3, 345 CONSTRUCTOR=4, 346 MODEL=5) 347 348SCOPE_NORMAL = 1 349SCOPE_SOURCE = 2 350SCOPE_VALUES = 4 351SCOPE_CTE = 8 352SCOPE_COLUMN = 16 353 354# Rules for parentheses around subqueries in compound select. 355CSQ_PARENTHESES_NEVER = 0 356CSQ_PARENTHESES_ALWAYS = 1 357CSQ_PARENTHESES_UNNESTED = 2 358 359# Regular expressions used to convert class names to snake-case table names. 360# First regex handles acronym followed by word or initial lower-word followed 361# by a capitalized word. e.g. APIResponse -> API_Response / fooBar -> foo_Bar. 362# Second regex handles the normal case of two title-cased words. 363SNAKE_CASE_STEP1 = re.compile('(.)_*([A-Z][a-z]+)') 364SNAKE_CASE_STEP2 = re.compile('([a-z0-9])_*([A-Z])') 365 366# Helper functions that are used in various parts of the codebase. 367MODEL_BASE = '_metaclass_helper_' 368 369def with_metaclass(meta, base=object): 370 return meta(MODEL_BASE, (base,), {}) 371 372def merge_dict(source, overrides): 373 merged = source.copy() 374 if overrides: 375 merged.update(overrides) 376 return merged 377 378def quote(path, quote_chars): 379 if len(path) == 1: 380 return path[0].join(quote_chars) 381 return '.'.join([part.join(quote_chars) for part in path]) 382 383is_model = lambda o: isclass(o) and issubclass(o, Model) 384 385def ensure_tuple(value): 386 if value is not None: 387 return value if isinstance(value, (list, tuple)) else (value,) 388 389def ensure_entity(value): 390 if value is not None: 391 return value if isinstance(value, Node) else Entity(value) 392 393def make_snake_case(s): 394 first = SNAKE_CASE_STEP1.sub(r'\1_\2', s) 395 return SNAKE_CASE_STEP2.sub(r'\1_\2', first).lower() 396 397def chunked(it, n): 398 marker = object() 399 for group in (list(g) for g in izip_longest(*[iter(it)] * n, 400 fillvalue=marker)): 401 if group[-1] is marker: 402 del group[group.index(marker):] 403 yield group 404 405 406class _callable_context_manager(object): 407 def __call__(self, fn): 408 @wraps(fn) 409 def inner(*args, **kwargs): 410 with self: 411 return fn(*args, **kwargs) 412 return inner 413 414 415class Proxy(object): 416 """ 417 Create a proxy or placeholder for another object. 418 """ 419 __slots__ = ('obj', '_callbacks') 420 421 def __init__(self): 422 self._callbacks = [] 423 self.initialize(None) 424 425 def initialize(self, obj): 426 self.obj = obj 427 for callback in self._callbacks: 428 callback(obj) 429 430 def attach_callback(self, callback): 431 self._callbacks.append(callback) 432 return callback 433 434 def passthrough(method): 435 def inner(self, *args, **kwargs): 436 if self.obj is None: 437 raise AttributeError('Cannot use uninitialized Proxy.') 438 return getattr(self.obj, method)(*args, **kwargs) 439 return inner 440 441 # Allow proxy to be used as a context-manager. 442 __enter__ = passthrough('__enter__') 443 __exit__ = passthrough('__exit__') 444 445 def __getattr__(self, attr): 446 if self.obj is None: 447 raise AttributeError('Cannot use uninitialized Proxy.') 448 return getattr(self.obj, attr) 449 450 def __setattr__(self, attr, value): 451 if attr not in self.__slots__: 452 raise AttributeError('Cannot set attribute on proxy.') 453 return super(Proxy, self).__setattr__(attr, value) 454 455 456class DatabaseProxy(Proxy): 457 """ 458 Proxy implementation specifically for proxying `Database` objects. 459 """ 460 def connection_context(self): 461 return ConnectionContext(self) 462 def atomic(self, *args, **kwargs): 463 return _atomic(self, *args, **kwargs) 464 def manual_commit(self): 465 return _manual(self) 466 def transaction(self, *args, **kwargs): 467 return _transaction(self, *args, **kwargs) 468 def savepoint(self): 469 return _savepoint(self) 470 471 472class ModelDescriptor(object): pass 473 474 475# SQL Generation. 476 477 478class AliasManager(object): 479 __slots__ = ('_counter', '_current_index', '_mapping') 480 481 def __init__(self): 482 # A list of dictionaries containing mappings at various depths. 483 self._counter = 0 484 self._current_index = 0 485 self._mapping = [] 486 self.push() 487 488 @property 489 def mapping(self): 490 return self._mapping[self._current_index - 1] 491 492 def add(self, source): 493 if source not in self.mapping: 494 self._counter += 1 495 self[source] = 't%d' % self._counter 496 return self.mapping[source] 497 498 def get(self, source, any_depth=False): 499 if any_depth: 500 for idx in reversed(range(self._current_index)): 501 if source in self._mapping[idx]: 502 return self._mapping[idx][source] 503 return self.add(source) 504 505 def __getitem__(self, source): 506 return self.get(source) 507 508 def __setitem__(self, source, alias): 509 self.mapping[source] = alias 510 511 def push(self): 512 self._current_index += 1 513 if self._current_index > len(self._mapping): 514 self._mapping.append({}) 515 516 def pop(self): 517 if self._current_index == 1: 518 raise ValueError('Cannot pop() from empty alias manager.') 519 self._current_index -= 1 520 521 522class State(collections.namedtuple('_State', ('scope', 'parentheses', 523 'settings'))): 524 def __new__(cls, scope=SCOPE_NORMAL, parentheses=False, **kwargs): 525 return super(State, cls).__new__(cls, scope, parentheses, kwargs) 526 527 def __call__(self, scope=None, parentheses=None, **kwargs): 528 # Scope and settings are "inherited" (parentheses is not, however). 529 scope = self.scope if scope is None else scope 530 531 # Try to avoid unnecessary dict copying. 532 if kwargs and self.settings: 533 settings = self.settings.copy() # Copy original settings dict. 534 settings.update(kwargs) # Update copy with overrides. 535 elif kwargs: 536 settings = kwargs 537 else: 538 settings = self.settings 539 return State(scope, parentheses, **settings) 540 541 def __getattr__(self, attr_name): 542 return self.settings.get(attr_name) 543 544 545def __scope_context__(scope): 546 @contextmanager 547 def inner(self, **kwargs): 548 with self(scope=scope, **kwargs): 549 yield self 550 return inner 551 552 553class Context(object): 554 __slots__ = ('stack', '_sql', '_values', 'alias_manager', 'state') 555 556 def __init__(self, **settings): 557 self.stack = [] 558 self._sql = [] 559 self._values = [] 560 self.alias_manager = AliasManager() 561 self.state = State(**settings) 562 563 def as_new(self): 564 return Context(**self.state.settings) 565 566 def column_sort_key(self, item): 567 return item[0].get_sort_key(self) 568 569 @property 570 def scope(self): 571 return self.state.scope 572 573 @property 574 def parentheses(self): 575 return self.state.parentheses 576 577 @property 578 def subquery(self): 579 return self.state.subquery 580 581 def __call__(self, **overrides): 582 if overrides and overrides.get('scope') == self.scope: 583 del overrides['scope'] 584 585 self.stack.append(self.state) 586 self.state = self.state(**overrides) 587 return self 588 589 scope_normal = __scope_context__(SCOPE_NORMAL) 590 scope_source = __scope_context__(SCOPE_SOURCE) 591 scope_values = __scope_context__(SCOPE_VALUES) 592 scope_cte = __scope_context__(SCOPE_CTE) 593 scope_column = __scope_context__(SCOPE_COLUMN) 594 595 def __enter__(self): 596 if self.parentheses: 597 self.literal('(') 598 return self 599 600 def __exit__(self, exc_type, exc_val, exc_tb): 601 if self.parentheses: 602 self.literal(')') 603 self.state = self.stack.pop() 604 605 @contextmanager 606 def push_alias(self): 607 self.alias_manager.push() 608 yield 609 self.alias_manager.pop() 610 611 def sql(self, obj): 612 if isinstance(obj, (Node, Context)): 613 return obj.__sql__(self) 614 elif is_model(obj): 615 return obj._meta.table.__sql__(self) 616 else: 617 return self.sql(Value(obj)) 618 619 def literal(self, keyword): 620 self._sql.append(keyword) 621 return self 622 623 def value(self, value, converter=None, add_param=True): 624 if converter: 625 value = converter(value) 626 elif converter is None and self.state.converter: 627 # Explicitly check for None so that "False" can be used to signify 628 # that no conversion should be applied. 629 value = self.state.converter(value) 630 631 if isinstance(value, Node): 632 with self(converter=None): 633 return self.sql(value) 634 elif is_model(value): 635 # Under certain circumstances, we could end-up treating a model- 636 # class itself as a value. This check ensures that we drop the 637 # table alias into the query instead of trying to parameterize a 638 # model (for instance, passing a model as a function argument). 639 with self.scope_column(): 640 return self.sql(value) 641 642 self._values.append(value) 643 return self.literal(self.state.param or '?') if add_param else self 644 645 def __sql__(self, ctx): 646 ctx._sql.extend(self._sql) 647 ctx._values.extend(self._values) 648 return ctx 649 650 def parse(self, node): 651 return self.sql(node).query() 652 653 def query(self): 654 return ''.join(self._sql), self._values 655 656 657def query_to_string(query): 658 # NOTE: this function is not exported by default as it might be misused -- 659 # and this misuse could lead to sql injection vulnerabilities. This 660 # function is intended for debugging or logging purposes ONLY. 661 db = getattr(query, '_database', None) 662 if db is not None: 663 ctx = db.get_sql_context() 664 else: 665 ctx = Context() 666 667 sql, params = ctx.sql(query).query() 668 if not params: 669 return sql 670 671 param = ctx.state.param or '?' 672 if param == '?': 673 sql = sql.replace('?', '%s') 674 675 return sql % tuple(map(_query_val_transform, params)) 676 677def _query_val_transform(v): 678 # Interpolate parameters. 679 if isinstance(v, (text_type, datetime.datetime, datetime.date, 680 datetime.time)): 681 v = "'%s'" % v 682 elif isinstance(v, bytes_type): 683 try: 684 v = v.decode('utf8') 685 except UnicodeDecodeError: 686 v = v.decode('raw_unicode_escape') 687 v = "'%s'" % v 688 elif isinstance(v, int): 689 v = '%s' % int(v) # Also handles booleans -> 1 or 0. 690 elif v is None: 691 v = 'NULL' 692 else: 693 v = str(v) 694 return v 695 696 697# AST. 698 699 700class Node(object): 701 _coerce = True 702 703 def clone(self): 704 obj = self.__class__.__new__(self.__class__) 705 obj.__dict__ = self.__dict__.copy() 706 return obj 707 708 def __sql__(self, ctx): 709 raise NotImplementedError 710 711 @staticmethod 712 def copy(method): 713 def inner(self, *args, **kwargs): 714 clone = self.clone() 715 method(clone, *args, **kwargs) 716 return clone 717 return inner 718 719 def coerce(self, _coerce=True): 720 if _coerce != self._coerce: 721 clone = self.clone() 722 clone._coerce = _coerce 723 return clone 724 return self 725 726 def is_alias(self): 727 return False 728 729 def unwrap(self): 730 return self 731 732 733class ColumnFactory(object): 734 __slots__ = ('node',) 735 736 def __init__(self, node): 737 self.node = node 738 739 def __getattr__(self, attr): 740 return Column(self.node, attr) 741 742 743class _DynamicColumn(object): 744 __slots__ = () 745 746 def __get__(self, instance, instance_type=None): 747 if instance is not None: 748 return ColumnFactory(instance) # Implements __getattr__(). 749 return self 750 751 752class _ExplicitColumn(object): 753 __slots__ = () 754 755 def __get__(self, instance, instance_type=None): 756 if instance is not None: 757 raise AttributeError( 758 '%s specifies columns explicitly, and does not support ' 759 'dynamic column lookups.' % instance) 760 return self 761 762 763class Source(Node): 764 c = _DynamicColumn() 765 766 def __init__(self, alias=None): 767 super(Source, self).__init__() 768 self._alias = alias 769 770 @Node.copy 771 def alias(self, name): 772 self._alias = name 773 774 def select(self, *columns): 775 if not columns: 776 columns = (SQL('*'),) 777 return Select((self,), columns) 778 779 def join(self, dest, join_type=JOIN.INNER, on=None): 780 return Join(self, dest, join_type, on) 781 782 def left_outer_join(self, dest, on=None): 783 return Join(self, dest, JOIN.LEFT_OUTER, on) 784 785 def cte(self, name, recursive=False, columns=None, materialized=None): 786 return CTE(name, self, recursive=recursive, columns=columns, 787 materialized=materialized) 788 789 def get_sort_key(self, ctx): 790 if self._alias: 791 return (self._alias,) 792 return (ctx.alias_manager[self],) 793 794 def apply_alias(self, ctx): 795 # If we are defining the source, include the "AS alias" declaration. An 796 # alias is created for the source if one is not already defined. 797 if ctx.scope == SCOPE_SOURCE: 798 if self._alias: 799 ctx.alias_manager[self] = self._alias 800 ctx.literal(' AS ').sql(Entity(ctx.alias_manager[self])) 801 return ctx 802 803 def apply_column(self, ctx): 804 if self._alias: 805 ctx.alias_manager[self] = self._alias 806 return ctx.sql(Entity(ctx.alias_manager[self])) 807 808 809class _HashableSource(object): 810 def __init__(self, *args, **kwargs): 811 super(_HashableSource, self).__init__(*args, **kwargs) 812 self._update_hash() 813 814 @Node.copy 815 def alias(self, name): 816 self._alias = name 817 self._update_hash() 818 819 def _update_hash(self): 820 self._hash = self._get_hash() 821 822 def _get_hash(self): 823 return hash((self.__class__, self._path, self._alias)) 824 825 def __hash__(self): 826 return self._hash 827 828 def __eq__(self, other): 829 if isinstance(other, _HashableSource): 830 return self._hash == other._hash 831 return Expression(self, OP.EQ, other) 832 833 def __ne__(self, other): 834 if isinstance(other, _HashableSource): 835 return self._hash != other._hash 836 return Expression(self, OP.NE, other) 837 838 def _e(op): 839 def inner(self, rhs): 840 return Expression(self, op, rhs) 841 return inner 842 __lt__ = _e(OP.LT) 843 __le__ = _e(OP.LTE) 844 __gt__ = _e(OP.GT) 845 __ge__ = _e(OP.GTE) 846 847 848def __bind_database__(meth): 849 @wraps(meth) 850 def inner(self, *args, **kwargs): 851 result = meth(self, *args, **kwargs) 852 if self._database: 853 return result.bind(self._database) 854 return result 855 return inner 856 857 858def __join__(join_type=JOIN.INNER, inverted=False): 859 def method(self, other): 860 if inverted: 861 self, other = other, self 862 return Join(self, other, join_type=join_type) 863 return method 864 865 866class BaseTable(Source): 867 __and__ = __join__(JOIN.INNER) 868 __add__ = __join__(JOIN.LEFT_OUTER) 869 __sub__ = __join__(JOIN.RIGHT_OUTER) 870 __or__ = __join__(JOIN.FULL_OUTER) 871 __mul__ = __join__(JOIN.CROSS) 872 __rand__ = __join__(JOIN.INNER, inverted=True) 873 __radd__ = __join__(JOIN.LEFT_OUTER, inverted=True) 874 __rsub__ = __join__(JOIN.RIGHT_OUTER, inverted=True) 875 __ror__ = __join__(JOIN.FULL_OUTER, inverted=True) 876 __rmul__ = __join__(JOIN.CROSS, inverted=True) 877 878 879class _BoundTableContext(_callable_context_manager): 880 def __init__(self, table, database): 881 self.table = table 882 self.database = database 883 884 def __enter__(self): 885 self._orig_database = self.table._database 886 self.table.bind(self.database) 887 if self.table._model is not None: 888 self.table._model.bind(self.database) 889 return self.table 890 891 def __exit__(self, exc_type, exc_val, exc_tb): 892 self.table.bind(self._orig_database) 893 if self.table._model is not None: 894 self.table._model.bind(self._orig_database) 895 896 897class Table(_HashableSource, BaseTable): 898 def __init__(self, name, columns=None, primary_key=None, schema=None, 899 alias=None, _model=None, _database=None): 900 self.__name__ = name 901 self._columns = columns 902 self._primary_key = primary_key 903 self._schema = schema 904 self._path = (schema, name) if schema else (name,) 905 self._model = _model 906 self._database = _database 907 super(Table, self).__init__(alias=alias) 908 909 # Allow tables to restrict what columns are available. 910 if columns is not None: 911 self.c = _ExplicitColumn() 912 for column in columns: 913 setattr(self, column, Column(self, column)) 914 915 if primary_key: 916 col_src = self if self._columns else self.c 917 self.primary_key = getattr(col_src, primary_key) 918 else: 919 self.primary_key = None 920 921 def clone(self): 922 # Ensure a deep copy of the column instances. 923 return Table( 924 self.__name__, 925 columns=self._columns, 926 primary_key=self._primary_key, 927 schema=self._schema, 928 alias=self._alias, 929 _model=self._model, 930 _database=self._database) 931 932 def bind(self, database=None): 933 self._database = database 934 return self 935 936 def bind_ctx(self, database=None): 937 return _BoundTableContext(self, database) 938 939 def _get_hash(self): 940 return hash((self.__class__, self._path, self._alias, self._model)) 941 942 @__bind_database__ 943 def select(self, *columns): 944 if not columns and self._columns: 945 columns = [Column(self, column) for column in self._columns] 946 return Select((self,), columns) 947 948 @__bind_database__ 949 def insert(self, insert=None, columns=None, **kwargs): 950 if kwargs: 951 insert = {} if insert is None else insert 952 src = self if self._columns else self.c 953 for key, value in kwargs.items(): 954 insert[getattr(src, key)] = value 955 return Insert(self, insert=insert, columns=columns) 956 957 @__bind_database__ 958 def replace(self, insert=None, columns=None, **kwargs): 959 return (self 960 .insert(insert=insert, columns=columns) 961 .on_conflict('REPLACE')) 962 963 @__bind_database__ 964 def update(self, update=None, **kwargs): 965 if kwargs: 966 update = {} if update is None else update 967 for key, value in kwargs.items(): 968 src = self if self._columns else self.c 969 update[getattr(src, key)] = value 970 return Update(self, update=update) 971 972 @__bind_database__ 973 def delete(self): 974 return Delete(self) 975 976 def __sql__(self, ctx): 977 if ctx.scope == SCOPE_VALUES: 978 # Return the quoted table name. 979 return ctx.sql(Entity(*self._path)) 980 981 if self._alias: 982 ctx.alias_manager[self] = self._alias 983 984 if ctx.scope == SCOPE_SOURCE: 985 # Define the table and its alias. 986 return self.apply_alias(ctx.sql(Entity(*self._path))) 987 else: 988 # Refer to the table using the alias. 989 return self.apply_column(ctx) 990 991 992class Join(BaseTable): 993 def __init__(self, lhs, rhs, join_type=JOIN.INNER, on=None, alias=None): 994 super(Join, self).__init__(alias=alias) 995 self.lhs = lhs 996 self.rhs = rhs 997 self.join_type = join_type 998 self._on = on 999 1000 def on(self, predicate): 1001 self._on = predicate 1002 return self 1003 1004 def __sql__(self, ctx): 1005 (ctx 1006 .sql(self.lhs) 1007 .literal(' %s ' % self.join_type) 1008 .sql(self.rhs)) 1009 if self._on is not None: 1010 ctx.literal(' ON ').sql(self._on) 1011 return ctx 1012 1013 1014class ValuesList(_HashableSource, BaseTable): 1015 def __init__(self, values, columns=None, alias=None): 1016 self._values = values 1017 self._columns = columns 1018 super(ValuesList, self).__init__(alias=alias) 1019 1020 def _get_hash(self): 1021 return hash((self.__class__, id(self._values), self._alias)) 1022 1023 @Node.copy 1024 def columns(self, *names): 1025 self._columns = names 1026 1027 def __sql__(self, ctx): 1028 if self._alias: 1029 ctx.alias_manager[self] = self._alias 1030 1031 if ctx.scope == SCOPE_SOURCE or ctx.scope == SCOPE_NORMAL: 1032 with ctx(parentheses=not ctx.parentheses): 1033 ctx = (ctx 1034 .literal('VALUES ') 1035 .sql(CommaNodeList([ 1036 EnclosedNodeList(row) for row in self._values]))) 1037 1038 if ctx.scope == SCOPE_SOURCE: 1039 ctx.literal(' AS ').sql(Entity(ctx.alias_manager[self])) 1040 if self._columns: 1041 entities = [Entity(c) for c in self._columns] 1042 ctx.sql(EnclosedNodeList(entities)) 1043 else: 1044 ctx.sql(Entity(ctx.alias_manager[self])) 1045 1046 return ctx 1047 1048 1049class CTE(_HashableSource, Source): 1050 def __init__(self, name, query, recursive=False, columns=None, 1051 materialized=None): 1052 self._alias = name 1053 self._query = query 1054 self._recursive = recursive 1055 self._materialized = materialized 1056 if columns is not None: 1057 columns = [Entity(c) if isinstance(c, basestring) else c 1058 for c in columns] 1059 self._columns = columns 1060 query._cte_list = () 1061 super(CTE, self).__init__(alias=name) 1062 1063 def select_from(self, *columns): 1064 if not columns: 1065 raise ValueError('select_from() must specify one or more columns ' 1066 'from the CTE to select.') 1067 1068 query = (Select((self,), columns) 1069 .with_cte(self) 1070 .bind(self._query._database)) 1071 try: 1072 query = query.objects(self._query.model) 1073 except AttributeError: 1074 pass 1075 return query 1076 1077 def _get_hash(self): 1078 return hash((self.__class__, self._alias, id(self._query))) 1079 1080 def union_all(self, rhs): 1081 clone = self._query.clone() 1082 return CTE(self._alias, clone + rhs, self._recursive, self._columns) 1083 __add__ = union_all 1084 1085 def union(self, rhs): 1086 clone = self._query.clone() 1087 return CTE(self._alias, clone | rhs, self._recursive, self._columns) 1088 __or__ = union 1089 1090 def __sql__(self, ctx): 1091 if ctx.scope != SCOPE_CTE: 1092 return ctx.sql(Entity(self._alias)) 1093 1094 with ctx.push_alias(): 1095 ctx.alias_manager[self] = self._alias 1096 ctx.sql(Entity(self._alias)) 1097 1098 if self._columns: 1099 ctx.literal(' ').sql(EnclosedNodeList(self._columns)) 1100 ctx.literal(' AS ') 1101 1102 if self._materialized: 1103 ctx.literal('MATERIALIZED ') 1104 elif self._materialized is False: 1105 ctx.literal('NOT MATERIALIZED ') 1106 1107 with ctx.scope_normal(parentheses=True): 1108 ctx.sql(self._query) 1109 return ctx 1110 1111 1112class ColumnBase(Node): 1113 _converter = None 1114 1115 @Node.copy 1116 def converter(self, converter=None): 1117 self._converter = converter 1118 1119 def alias(self, alias): 1120 if alias: 1121 return Alias(self, alias) 1122 return self 1123 1124 def unalias(self): 1125 return self 1126 1127 def cast(self, as_type): 1128 return Cast(self, as_type) 1129 1130 def asc(self, collation=None, nulls=None): 1131 return Asc(self, collation=collation, nulls=nulls) 1132 __pos__ = asc 1133 1134 def desc(self, collation=None, nulls=None): 1135 return Desc(self, collation=collation, nulls=nulls) 1136 __neg__ = desc 1137 1138 def __invert__(self): 1139 return Negated(self) 1140 1141 def _e(op, inv=False): 1142 """ 1143 Lightweight factory which returns a method that builds an Expression 1144 consisting of the left-hand and right-hand operands, using `op`. 1145 """ 1146 def inner(self, rhs): 1147 if inv: 1148 return Expression(rhs, op, self) 1149 return Expression(self, op, rhs) 1150 return inner 1151 __and__ = _e(OP.AND) 1152 __or__ = _e(OP.OR) 1153 1154 __add__ = _e(OP.ADD) 1155 __sub__ = _e(OP.SUB) 1156 __mul__ = _e(OP.MUL) 1157 __div__ = __truediv__ = _e(OP.DIV) 1158 __xor__ = _e(OP.XOR) 1159 __radd__ = _e(OP.ADD, inv=True) 1160 __rsub__ = _e(OP.SUB, inv=True) 1161 __rmul__ = _e(OP.MUL, inv=True) 1162 __rdiv__ = __rtruediv__ = _e(OP.DIV, inv=True) 1163 __rand__ = _e(OP.AND, inv=True) 1164 __ror__ = _e(OP.OR, inv=True) 1165 __rxor__ = _e(OP.XOR, inv=True) 1166 1167 def __eq__(self, rhs): 1168 op = OP.IS if rhs is None else OP.EQ 1169 return Expression(self, op, rhs) 1170 def __ne__(self, rhs): 1171 op = OP.IS_NOT if rhs is None else OP.NE 1172 return Expression(self, op, rhs) 1173 1174 __lt__ = _e(OP.LT) 1175 __le__ = _e(OP.LTE) 1176 __gt__ = _e(OP.GT) 1177 __ge__ = _e(OP.GTE) 1178 __lshift__ = _e(OP.IN) 1179 __rshift__ = _e(OP.IS) 1180 __mod__ = _e(OP.LIKE) 1181 __pow__ = _e(OP.ILIKE) 1182 1183 bin_and = _e(OP.BIN_AND) 1184 bin_or = _e(OP.BIN_OR) 1185 in_ = _e(OP.IN) 1186 not_in = _e(OP.NOT_IN) 1187 regexp = _e(OP.REGEXP) 1188 1189 # Special expressions. 1190 def is_null(self, is_null=True): 1191 op = OP.IS if is_null else OP.IS_NOT 1192 return Expression(self, op, None) 1193 1194 def _escape_like_expr(self, s, template): 1195 if s.find('_') >= 0 or s.find('%') >= 0 or s.find('\\') >= 0: 1196 s = s.replace('\\', '\\\\').replace('_', '\\_').replace('%', '\\%') 1197 return NodeList((template % s, SQL('ESCAPE'), '\\')) 1198 return template % s 1199 def contains(self, rhs): 1200 if isinstance(rhs, Node): 1201 rhs = Expression('%', OP.CONCAT, 1202 Expression(rhs, OP.CONCAT, '%')) 1203 else: 1204 rhs = self._escape_like_expr(rhs, '%%%s%%') 1205 return Expression(self, OP.ILIKE, rhs) 1206 def startswith(self, rhs): 1207 if isinstance(rhs, Node): 1208 rhs = Expression(rhs, OP.CONCAT, '%') 1209 else: 1210 rhs = self._escape_like_expr(rhs, '%s%%') 1211 return Expression(self, OP.ILIKE, rhs) 1212 def endswith(self, rhs): 1213 if isinstance(rhs, Node): 1214 rhs = Expression('%', OP.CONCAT, rhs) 1215 else: 1216 rhs = self._escape_like_expr(rhs, '%%%s') 1217 return Expression(self, OP.ILIKE, rhs) 1218 def between(self, lo, hi): 1219 return Expression(self, OP.BETWEEN, NodeList((lo, SQL('AND'), hi))) 1220 def concat(self, rhs): 1221 return StringExpression(self, OP.CONCAT, rhs) 1222 def regexp(self, rhs): 1223 return Expression(self, OP.REGEXP, rhs) 1224 def iregexp(self, rhs): 1225 return Expression(self, OP.IREGEXP, rhs) 1226 def __getitem__(self, item): 1227 if isinstance(item, slice): 1228 if item.start is None or item.stop is None: 1229 raise ValueError('BETWEEN range must have both a start- and ' 1230 'end-point.') 1231 return self.between(item.start, item.stop) 1232 return self == item 1233 1234 def distinct(self): 1235 return NodeList((SQL('DISTINCT'), self)) 1236 1237 def collate(self, collation): 1238 return NodeList((self, SQL('COLLATE %s' % collation))) 1239 1240 def get_sort_key(self, ctx): 1241 return () 1242 1243 1244class Column(ColumnBase): 1245 def __init__(self, source, name): 1246 self.source = source 1247 self.name = name 1248 1249 def get_sort_key(self, ctx): 1250 if ctx.scope == SCOPE_VALUES: 1251 return (self.name,) 1252 else: 1253 return self.source.get_sort_key(ctx) + (self.name,) 1254 1255 def __hash__(self): 1256 return hash((self.source, self.name)) 1257 1258 def __sql__(self, ctx): 1259 if ctx.scope == SCOPE_VALUES: 1260 return ctx.sql(Entity(self.name)) 1261 else: 1262 with ctx.scope_column(): 1263 return ctx.sql(self.source).literal('.').sql(Entity(self.name)) 1264 1265 1266class WrappedNode(ColumnBase): 1267 def __init__(self, node): 1268 self.node = node 1269 self._coerce = getattr(node, '_coerce', True) 1270 self._converter = getattr(node, '_converter', None) 1271 1272 def is_alias(self): 1273 return self.node.is_alias() 1274 1275 def unwrap(self): 1276 return self.node.unwrap() 1277 1278 1279class EntityFactory(object): 1280 __slots__ = ('node',) 1281 def __init__(self, node): 1282 self.node = node 1283 def __getattr__(self, attr): 1284 return Entity(self.node, attr) 1285 1286 1287class _DynamicEntity(object): 1288 __slots__ = () 1289 def __get__(self, instance, instance_type=None): 1290 if instance is not None: 1291 return EntityFactory(instance._alias) # Implements __getattr__(). 1292 return self 1293 1294 1295class Alias(WrappedNode): 1296 c = _DynamicEntity() 1297 1298 def __init__(self, node, alias): 1299 super(Alias, self).__init__(node) 1300 self._alias = alias 1301 1302 def __hash__(self): 1303 return hash(self._alias) 1304 1305 def alias(self, alias=None): 1306 if alias is None: 1307 return self.node 1308 else: 1309 return Alias(self.node, alias) 1310 1311 def unalias(self): 1312 return self.node 1313 1314 def is_alias(self): 1315 return True 1316 1317 def __sql__(self, ctx): 1318 if ctx.scope == SCOPE_SOURCE: 1319 return (ctx 1320 .sql(self.node) 1321 .literal(' AS ') 1322 .sql(Entity(self._alias))) 1323 else: 1324 return ctx.sql(Entity(self._alias)) 1325 1326 1327class Negated(WrappedNode): 1328 def __invert__(self): 1329 return self.node 1330 1331 def __sql__(self, ctx): 1332 return ctx.literal('NOT ').sql(self.node) 1333 1334 1335class BitwiseMixin(object): 1336 def __and__(self, other): 1337 return self.bin_and(other) 1338 1339 def __or__(self, other): 1340 return self.bin_or(other) 1341 1342 def __sub__(self, other): 1343 return self.bin_and(other.bin_negated()) 1344 1345 def __invert__(self): 1346 return BitwiseNegated(self) 1347 1348 1349class BitwiseNegated(BitwiseMixin, WrappedNode): 1350 def __invert__(self): 1351 return self.node 1352 1353 def __sql__(self, ctx): 1354 if ctx.state.operations: 1355 op_sql = ctx.state.operations.get(self.op, self.op) 1356 else: 1357 op_sql = self.op 1358 return ctx.literal(op_sql).sql(self.node) 1359 1360 1361class Value(ColumnBase): 1362 def __init__(self, value, converter=None, unpack=True): 1363 self.value = value 1364 self.converter = converter 1365 self.multi = unpack and isinstance(self.value, multi_types) 1366 if self.multi: 1367 self.values = [] 1368 for item in self.value: 1369 if isinstance(item, Node): 1370 self.values.append(item) 1371 else: 1372 self.values.append(Value(item, self.converter)) 1373 1374 def __sql__(self, ctx): 1375 if self.multi: 1376 # For multi-part values (e.g. lists of IDs). 1377 return ctx.sql(EnclosedNodeList(self.values)) 1378 1379 return ctx.value(self.value, self.converter) 1380 1381 1382def AsIs(value): 1383 return Value(value, unpack=False) 1384 1385 1386class Cast(WrappedNode): 1387 def __init__(self, node, cast): 1388 super(Cast, self).__init__(node) 1389 self._cast = cast 1390 self._coerce = False 1391 1392 def __sql__(self, ctx): 1393 return (ctx 1394 .literal('CAST(') 1395 .sql(self.node) 1396 .literal(' AS %s)' % self._cast)) 1397 1398 1399class Ordering(WrappedNode): 1400 def __init__(self, node, direction, collation=None, nulls=None): 1401 super(Ordering, self).__init__(node) 1402 self.direction = direction 1403 self.collation = collation 1404 self.nulls = nulls 1405 if nulls and nulls.lower() not in ('first', 'last'): 1406 raise ValueError('Ordering nulls= parameter must be "first" or ' 1407 '"last", got: %s' % nulls) 1408 1409 def collate(self, collation=None): 1410 return Ordering(self.node, self.direction, collation) 1411 1412 def _null_ordering_case(self, nulls): 1413 if nulls.lower() == 'last': 1414 ifnull, notnull = 1, 0 1415 elif nulls.lower() == 'first': 1416 ifnull, notnull = 0, 1 1417 else: 1418 raise ValueError('unsupported value for nulls= ordering.') 1419 return Case(None, ((self.node.is_null(), ifnull),), notnull) 1420 1421 def __sql__(self, ctx): 1422 if self.nulls and not ctx.state.nulls_ordering: 1423 ctx.sql(self._null_ordering_case(self.nulls)).literal(', ') 1424 1425 ctx.sql(self.node).literal(' %s' % self.direction) 1426 if self.collation: 1427 ctx.literal(' COLLATE %s' % self.collation) 1428 if self.nulls and ctx.state.nulls_ordering: 1429 ctx.literal(' NULLS %s' % self.nulls) 1430 return ctx 1431 1432 1433def Asc(node, collation=None, nulls=None): 1434 return Ordering(node, 'ASC', collation, nulls) 1435 1436 1437def Desc(node, collation=None, nulls=None): 1438 return Ordering(node, 'DESC', collation, nulls) 1439 1440 1441class Expression(ColumnBase): 1442 def __init__(self, lhs, op, rhs, flat=False): 1443 self.lhs = lhs 1444 self.op = op 1445 self.rhs = rhs 1446 self.flat = flat 1447 1448 def __sql__(self, ctx): 1449 overrides = {'parentheses': not self.flat, 'in_expr': True} 1450 1451 # First attempt to unwrap the node on the left-hand-side, so that we 1452 # can get at the underlying Field if one is present. 1453 node = raw_node = self.lhs 1454 if isinstance(raw_node, WrappedNode): 1455 node = raw_node.unwrap() 1456 1457 # Set up the appropriate converter if we have a field on the left side. 1458 if isinstance(node, Field) and raw_node._coerce: 1459 overrides['converter'] = node.db_value 1460 overrides['is_fk_expr'] = isinstance(node, ForeignKeyField) 1461 else: 1462 overrides['converter'] = None 1463 1464 if ctx.state.operations: 1465 op_sql = ctx.state.operations.get(self.op, self.op) 1466 else: 1467 op_sql = self.op 1468 1469 with ctx(**overrides): 1470 # Postgresql reports an error for IN/NOT IN (), so convert to 1471 # the equivalent boolean expression. 1472 op_in = self.op == OP.IN or self.op == OP.NOT_IN 1473 if op_in and ctx.as_new().parse(self.rhs)[0] == '()': 1474 return ctx.literal('0 = 1' if self.op == OP.IN else '1 = 1') 1475 1476 return (ctx 1477 .sql(self.lhs) 1478 .literal(' %s ' % op_sql) 1479 .sql(self.rhs)) 1480 1481 1482class StringExpression(Expression): 1483 def __add__(self, rhs): 1484 return self.concat(rhs) 1485 def __radd__(self, lhs): 1486 return StringExpression(lhs, OP.CONCAT, self) 1487 1488 1489class Entity(ColumnBase): 1490 def __init__(self, *path): 1491 self._path = [part.replace('"', '""') for part in path if part] 1492 1493 def __getattr__(self, attr): 1494 return Entity(*self._path + [attr]) 1495 1496 def get_sort_key(self, ctx): 1497 return tuple(self._path) 1498 1499 def __hash__(self): 1500 return hash((self.__class__.__name__, tuple(self._path))) 1501 1502 def __sql__(self, ctx): 1503 return ctx.literal(quote(self._path, ctx.state.quote or '""')) 1504 1505 1506class SQL(ColumnBase): 1507 def __init__(self, sql, params=None): 1508 self.sql = sql 1509 self.params = params 1510 1511 def __sql__(self, ctx): 1512 ctx.literal(self.sql) 1513 if self.params: 1514 for param in self.params: 1515 ctx.value(param, False, add_param=False) 1516 return ctx 1517 1518 1519def Check(constraint, name=None): 1520 check = SQL('CHECK (%s)' % constraint) 1521 if not name: 1522 return check 1523 return NodeList((SQL('CONSTRAINT'), Entity(name), check)) 1524 1525 1526class Function(ColumnBase): 1527 def __init__(self, name, arguments, coerce=True, python_value=None): 1528 self.name = name 1529 self.arguments = arguments 1530 self._filter = None 1531 self._order_by = None 1532 self._python_value = python_value 1533 if name and name.lower() in ('sum', 'count', 'cast', 'array_agg'): 1534 self._coerce = False 1535 else: 1536 self._coerce = coerce 1537 1538 def __getattr__(self, attr): 1539 def decorator(*args, **kwargs): 1540 return Function(attr, args, **kwargs) 1541 return decorator 1542 1543 @Node.copy 1544 def filter(self, where=None): 1545 self._filter = where 1546 1547 @Node.copy 1548 def order_by(self, *ordering): 1549 self._order_by = ordering 1550 1551 @Node.copy 1552 def python_value(self, func=None): 1553 self._python_value = func 1554 1555 def over(self, partition_by=None, order_by=None, start=None, end=None, 1556 frame_type=None, window=None, exclude=None): 1557 if isinstance(partition_by, Window) and window is None: 1558 window = partition_by 1559 1560 if window is not None: 1561 node = WindowAlias(window) 1562 else: 1563 node = Window(partition_by=partition_by, order_by=order_by, 1564 start=start, end=end, frame_type=frame_type, 1565 exclude=exclude, _inline=True) 1566 return NodeList((self, SQL('OVER'), node)) 1567 1568 def __sql__(self, ctx): 1569 ctx.literal(self.name) 1570 if not len(self.arguments): 1571 ctx.literal('()') 1572 else: 1573 args = self.arguments 1574 1575 # If this is an ordered aggregate, then we will modify the last 1576 # argument to append the ORDER BY ... clause. We do this to avoid 1577 # double-wrapping any expression args in parentheses, as NodeList 1578 # has a special check (hack) in place to work around this. 1579 if self._order_by: 1580 args = list(args) 1581 args[-1] = NodeList((args[-1], SQL('ORDER BY'), 1582 CommaNodeList(self._order_by))) 1583 1584 with ctx(in_function=True, function_arg_count=len(self.arguments)): 1585 ctx.sql(EnclosedNodeList([ 1586 (arg if isinstance(arg, Node) else Value(arg, False)) 1587 for arg in args])) 1588 1589 if self._filter: 1590 ctx.literal(' FILTER (WHERE ').sql(self._filter).literal(')') 1591 return ctx 1592 1593 1594fn = Function(None, None) 1595 1596 1597class Window(Node): 1598 # Frame start/end and frame exclusion. 1599 CURRENT_ROW = SQL('CURRENT ROW') 1600 GROUP = SQL('GROUP') 1601 TIES = SQL('TIES') 1602 NO_OTHERS = SQL('NO OTHERS') 1603 1604 # Frame types. 1605 GROUPS = 'GROUPS' 1606 RANGE = 'RANGE' 1607 ROWS = 'ROWS' 1608 1609 def __init__(self, partition_by=None, order_by=None, start=None, end=None, 1610 frame_type=None, extends=None, exclude=None, alias=None, 1611 _inline=False): 1612 super(Window, self).__init__() 1613 if start is not None and not isinstance(start, SQL): 1614 start = SQL(start) 1615 if end is not None and not isinstance(end, SQL): 1616 end = SQL(end) 1617 1618 self.partition_by = ensure_tuple(partition_by) 1619 self.order_by = ensure_tuple(order_by) 1620 self.start = start 1621 self.end = end 1622 if self.start is None and self.end is not None: 1623 raise ValueError('Cannot specify WINDOW end without start.') 1624 self._alias = alias or 'w' 1625 self._inline = _inline 1626 self.frame_type = frame_type 1627 self._extends = extends 1628 self._exclude = exclude 1629 1630 def alias(self, alias=None): 1631 self._alias = alias or 'w' 1632 return self 1633 1634 @Node.copy 1635 def as_range(self): 1636 self.frame_type = Window.RANGE 1637 1638 @Node.copy 1639 def as_rows(self): 1640 self.frame_type = Window.ROWS 1641 1642 @Node.copy 1643 def as_groups(self): 1644 self.frame_type = Window.GROUPS 1645 1646 @Node.copy 1647 def extends(self, window=None): 1648 self._extends = window 1649 1650 @Node.copy 1651 def exclude(self, frame_exclusion=None): 1652 if isinstance(frame_exclusion, basestring): 1653 frame_exclusion = SQL(frame_exclusion) 1654 self._exclude = frame_exclusion 1655 1656 @staticmethod 1657 def following(value=None): 1658 if value is None: 1659 return SQL('UNBOUNDED FOLLOWING') 1660 return SQL('%d FOLLOWING' % value) 1661 1662 @staticmethod 1663 def preceding(value=None): 1664 if value is None: 1665 return SQL('UNBOUNDED PRECEDING') 1666 return SQL('%d PRECEDING' % value) 1667 1668 def __sql__(self, ctx): 1669 if ctx.scope != SCOPE_SOURCE and not self._inline: 1670 ctx.literal(self._alias) 1671 ctx.literal(' AS ') 1672 1673 with ctx(parentheses=True): 1674 parts = [] 1675 if self._extends is not None: 1676 ext = self._extends 1677 if isinstance(ext, Window): 1678 ext = SQL(ext._alias) 1679 elif isinstance(ext, basestring): 1680 ext = SQL(ext) 1681 parts.append(ext) 1682 if self.partition_by: 1683 parts.extend(( 1684 SQL('PARTITION BY'), 1685 CommaNodeList(self.partition_by))) 1686 if self.order_by: 1687 parts.extend(( 1688 SQL('ORDER BY'), 1689 CommaNodeList(self.order_by))) 1690 if self.start is not None and self.end is not None: 1691 frame = self.frame_type or 'ROWS' 1692 parts.extend(( 1693 SQL('%s BETWEEN' % frame), 1694 self.start, 1695 SQL('AND'), 1696 self.end)) 1697 elif self.start is not None: 1698 parts.extend((SQL(self.frame_type or 'ROWS'), self.start)) 1699 elif self.frame_type is not None: 1700 parts.append(SQL('%s UNBOUNDED PRECEDING' % self.frame_type)) 1701 if self._exclude is not None: 1702 parts.extend((SQL('EXCLUDE'), self._exclude)) 1703 ctx.sql(NodeList(parts)) 1704 return ctx 1705 1706 1707class WindowAlias(Node): 1708 def __init__(self, window): 1709 self.window = window 1710 1711 def alias(self, window_alias): 1712 self.window._alias = window_alias 1713 return self 1714 1715 def __sql__(self, ctx): 1716 return ctx.literal(self.window._alias or 'w') 1717 1718 1719class ForUpdate(Node): 1720 def __init__(self, expr, of=None, nowait=None): 1721 expr = 'FOR UPDATE' if expr is True else expr 1722 if expr.lower().endswith('nowait'): 1723 expr = expr[:-7] # Strip off the "nowait" bit. 1724 nowait = True 1725 1726 self._expr = expr 1727 if of is not None and not isinstance(of, (list, set, tuple)): 1728 of = (of,) 1729 self._of = of 1730 self._nowait = nowait 1731 1732 def __sql__(self, ctx): 1733 ctx.literal(self._expr) 1734 if self._of is not None: 1735 ctx.literal(' OF ').sql(CommaNodeList(self._of)) 1736 if self._nowait: 1737 ctx.literal(' NOWAIT') 1738 return ctx 1739 1740 1741def Case(predicate, expression_tuples, default=None): 1742 clauses = [SQL('CASE')] 1743 if predicate is not None: 1744 clauses.append(predicate) 1745 for expr, value in expression_tuples: 1746 clauses.extend((SQL('WHEN'), expr, SQL('THEN'), value)) 1747 if default is not None: 1748 clauses.extend((SQL('ELSE'), default)) 1749 clauses.append(SQL('END')) 1750 return NodeList(clauses) 1751 1752 1753class NodeList(ColumnBase): 1754 def __init__(self, nodes, glue=' ', parens=False): 1755 self.nodes = nodes 1756 self.glue = glue 1757 self.parens = parens 1758 if parens and len(self.nodes) == 1 and \ 1759 isinstance(self.nodes[0], Expression) and \ 1760 not self.nodes[0].flat: 1761 # Hack to avoid double-parentheses. 1762 self.nodes = (self.nodes[0].clone(),) 1763 self.nodes[0].flat = True 1764 1765 def __sql__(self, ctx): 1766 n_nodes = len(self.nodes) 1767 if n_nodes == 0: 1768 return ctx.literal('()') if self.parens else ctx 1769 with ctx(parentheses=self.parens): 1770 for i in range(n_nodes - 1): 1771 ctx.sql(self.nodes[i]) 1772 ctx.literal(self.glue) 1773 ctx.sql(self.nodes[n_nodes - 1]) 1774 return ctx 1775 1776 1777def CommaNodeList(nodes): 1778 return NodeList(nodes, ', ') 1779 1780 1781def EnclosedNodeList(nodes): 1782 return NodeList(nodes, ', ', True) 1783 1784 1785class _Namespace(Node): 1786 __slots__ = ('_name',) 1787 def __init__(self, name): 1788 self._name = name 1789 def __getattr__(self, attr): 1790 return NamespaceAttribute(self, attr) 1791 __getitem__ = __getattr__ 1792 1793class NamespaceAttribute(ColumnBase): 1794 def __init__(self, namespace, attribute): 1795 self._namespace = namespace 1796 self._attribute = attribute 1797 1798 def __sql__(self, ctx): 1799 return (ctx 1800 .literal(self._namespace._name + '.') 1801 .sql(Entity(self._attribute))) 1802 1803EXCLUDED = _Namespace('EXCLUDED') 1804 1805 1806class DQ(ColumnBase): 1807 def __init__(self, **query): 1808 super(DQ, self).__init__() 1809 self.query = query 1810 self._negated = False 1811 1812 @Node.copy 1813 def __invert__(self): 1814 self._negated = not self._negated 1815 1816 def clone(self): 1817 node = DQ(**self.query) 1818 node._negated = self._negated 1819 return node 1820 1821#: Represent a row tuple. 1822Tuple = lambda *a: EnclosedNodeList(a) 1823 1824 1825class QualifiedNames(WrappedNode): 1826 def __sql__(self, ctx): 1827 with ctx.scope_column(): 1828 return ctx.sql(self.node) 1829 1830 1831def qualify_names(node): 1832 # Search a node heirarchy to ensure that any column-like objects are 1833 # referenced using fully-qualified names. 1834 if isinstance(node, Expression): 1835 return node.__class__(qualify_names(node.lhs), node.op, 1836 qualify_names(node.rhs), node.flat) 1837 elif isinstance(node, ColumnBase): 1838 return QualifiedNames(node) 1839 return node 1840 1841 1842class OnConflict(Node): 1843 def __init__(self, action=None, update=None, preserve=None, where=None, 1844 conflict_target=None, conflict_where=None, 1845 conflict_constraint=None): 1846 self._action = action 1847 self._update = update 1848 self._preserve = ensure_tuple(preserve) 1849 self._where = where 1850 if conflict_target is not None and conflict_constraint is not None: 1851 raise ValueError('only one of "conflict_target" and ' 1852 '"conflict_constraint" may be specified.') 1853 self._conflict_target = ensure_tuple(conflict_target) 1854 self._conflict_where = conflict_where 1855 self._conflict_constraint = conflict_constraint 1856 1857 def get_conflict_statement(self, ctx, query): 1858 return ctx.state.conflict_statement(self, query) 1859 1860 def get_conflict_update(self, ctx, query): 1861 return ctx.state.conflict_update(self, query) 1862 1863 @Node.copy 1864 def preserve(self, *columns): 1865 self._preserve = columns 1866 1867 @Node.copy 1868 def update(self, _data=None, **kwargs): 1869 if _data and kwargs and not isinstance(_data, dict): 1870 raise ValueError('Cannot mix data with keyword arguments in the ' 1871 'OnConflict update method.') 1872 _data = _data or {} 1873 if kwargs: 1874 _data.update(kwargs) 1875 self._update = _data 1876 1877 @Node.copy 1878 def where(self, *expressions): 1879 if self._where is not None: 1880 expressions = (self._where,) + expressions 1881 self._where = reduce(operator.and_, expressions) 1882 1883 @Node.copy 1884 def conflict_target(self, *constraints): 1885 self._conflict_constraint = None 1886 self._conflict_target = constraints 1887 1888 @Node.copy 1889 def conflict_where(self, *expressions): 1890 if self._conflict_where is not None: 1891 expressions = (self._conflict_where,) + expressions 1892 self._conflict_where = reduce(operator.and_, expressions) 1893 1894 @Node.copy 1895 def conflict_constraint(self, constraint): 1896 self._conflict_constraint = constraint 1897 self._conflict_target = None 1898 1899 1900def database_required(method): 1901 @wraps(method) 1902 def inner(self, database=None, *args, **kwargs): 1903 database = self._database if database is None else database 1904 if not database: 1905 raise InterfaceError('Query must be bound to a database in order ' 1906 'to call "%s".' % method.__name__) 1907 return method(self, database, *args, **kwargs) 1908 return inner 1909 1910# BASE QUERY INTERFACE. 1911 1912class BaseQuery(Node): 1913 default_row_type = ROW.DICT 1914 1915 def __init__(self, _database=None, **kwargs): 1916 self._database = _database 1917 self._cursor_wrapper = None 1918 self._row_type = None 1919 self._constructor = None 1920 super(BaseQuery, self).__init__(**kwargs) 1921 1922 def bind(self, database=None): 1923 self._database = database 1924 return self 1925 1926 def clone(self): 1927 query = super(BaseQuery, self).clone() 1928 query._cursor_wrapper = None 1929 return query 1930 1931 @Node.copy 1932 def dicts(self, as_dict=True): 1933 self._row_type = ROW.DICT if as_dict else None 1934 return self 1935 1936 @Node.copy 1937 def tuples(self, as_tuple=True): 1938 self._row_type = ROW.TUPLE if as_tuple else None 1939 return self 1940 1941 @Node.copy 1942 def namedtuples(self, as_namedtuple=True): 1943 self._row_type = ROW.NAMED_TUPLE if as_namedtuple else None 1944 return self 1945 1946 @Node.copy 1947 def objects(self, constructor=None): 1948 self._row_type = ROW.CONSTRUCTOR if constructor else None 1949 self._constructor = constructor 1950 return self 1951 1952 def _get_cursor_wrapper(self, cursor): 1953 row_type = self._row_type or self.default_row_type 1954 1955 if row_type == ROW.DICT: 1956 return DictCursorWrapper(cursor) 1957 elif row_type == ROW.TUPLE: 1958 return CursorWrapper(cursor) 1959 elif row_type == ROW.NAMED_TUPLE: 1960 return NamedTupleCursorWrapper(cursor) 1961 elif row_type == ROW.CONSTRUCTOR: 1962 return ObjectCursorWrapper(cursor, self._constructor) 1963 else: 1964 raise ValueError('Unrecognized row type: "%s".' % row_type) 1965 1966 def __sql__(self, ctx): 1967 raise NotImplementedError 1968 1969 def sql(self): 1970 if self._database: 1971 context = self._database.get_sql_context() 1972 else: 1973 context = Context() 1974 return context.parse(self) 1975 1976 @database_required 1977 def execute(self, database): 1978 return self._execute(database) 1979 1980 def _execute(self, database): 1981 raise NotImplementedError 1982 1983 def iterator(self, database=None): 1984 return iter(self.execute(database).iterator()) 1985 1986 def _ensure_execution(self): 1987 if not self._cursor_wrapper: 1988 if not self._database: 1989 raise ValueError('Query has not been executed.') 1990 self.execute() 1991 1992 def __iter__(self): 1993 self._ensure_execution() 1994 return iter(self._cursor_wrapper) 1995 1996 def __getitem__(self, value): 1997 self._ensure_execution() 1998 if isinstance(value, slice): 1999 index = value.stop 2000 else: 2001 index = value 2002 if index is not None: 2003 index = index + 1 if index >= 0 else 0 2004 self._cursor_wrapper.fill_cache(index) 2005 return self._cursor_wrapper.row_cache[value] 2006 2007 def __len__(self): 2008 self._ensure_execution() 2009 return len(self._cursor_wrapper) 2010 2011 def __str__(self): 2012 return query_to_string(self) 2013 2014 2015class RawQuery(BaseQuery): 2016 def __init__(self, sql=None, params=None, **kwargs): 2017 super(RawQuery, self).__init__(**kwargs) 2018 self._sql = sql 2019 self._params = params 2020 2021 def __sql__(self, ctx): 2022 ctx.literal(self._sql) 2023 if self._params: 2024 for param in self._params: 2025 ctx.value(param, add_param=False) 2026 return ctx 2027 2028 def _execute(self, database): 2029 if self._cursor_wrapper is None: 2030 cursor = database.execute(self) 2031 self._cursor_wrapper = self._get_cursor_wrapper(cursor) 2032 return self._cursor_wrapper 2033 2034 2035class Query(BaseQuery): 2036 def __init__(self, where=None, order_by=None, limit=None, offset=None, 2037 **kwargs): 2038 super(Query, self).__init__(**kwargs) 2039 self._where = where 2040 self._order_by = order_by 2041 self._limit = limit 2042 self._offset = offset 2043 2044 self._cte_list = None 2045 2046 @Node.copy 2047 def with_cte(self, *cte_list): 2048 self._cte_list = cte_list 2049 2050 @Node.copy 2051 def where(self, *expressions): 2052 if self._where is not None: 2053 expressions = (self._where,) + expressions 2054 self._where = reduce(operator.and_, expressions) 2055 2056 @Node.copy 2057 def orwhere(self, *expressions): 2058 if self._where is not None: 2059 expressions = (self._where,) + expressions 2060 self._where = reduce(operator.or_, expressions) 2061 2062 @Node.copy 2063 def order_by(self, *values): 2064 self._order_by = values 2065 2066 @Node.copy 2067 def order_by_extend(self, *values): 2068 self._order_by = ((self._order_by or ()) + values) or None 2069 2070 @Node.copy 2071 def limit(self, value=None): 2072 self._limit = value 2073 2074 @Node.copy 2075 def offset(self, value=None): 2076 self._offset = value 2077 2078 @Node.copy 2079 def paginate(self, page, paginate_by=20): 2080 if page > 0: 2081 page -= 1 2082 self._limit = paginate_by 2083 self._offset = page * paginate_by 2084 2085 def _apply_ordering(self, ctx): 2086 if self._order_by: 2087 (ctx 2088 .literal(' ORDER BY ') 2089 .sql(CommaNodeList(self._order_by))) 2090 if self._limit is not None or (self._offset is not None and 2091 ctx.state.limit_max): 2092 limit = ctx.state.limit_max if self._limit is None else self._limit 2093 ctx.literal(' LIMIT ').sql(limit) 2094 if self._offset is not None: 2095 ctx.literal(' OFFSET ').sql(self._offset) 2096 return ctx 2097 2098 def __sql__(self, ctx): 2099 if self._cte_list: 2100 # The CTE scope is only used at the very beginning of the query, 2101 # when we are describing the various CTEs we will be using. 2102 recursive = any(cte._recursive for cte in self._cte_list) 2103 2104 # Explicitly disable the "subquery" flag here, so as to avoid 2105 # unnecessary parentheses around subsequent selects. 2106 with ctx.scope_cte(subquery=False): 2107 (ctx 2108 .literal('WITH RECURSIVE ' if recursive else 'WITH ') 2109 .sql(CommaNodeList(self._cte_list)) 2110 .literal(' ')) 2111 return ctx 2112 2113 2114def __compound_select__(operation, inverted=False): 2115 def method(self, other): 2116 if inverted: 2117 self, other = other, self 2118 return CompoundSelectQuery(self, operation, other) 2119 return method 2120 2121 2122class SelectQuery(Query): 2123 union_all = __add__ = __compound_select__('UNION ALL') 2124 union = __or__ = __compound_select__('UNION') 2125 intersect = __and__ = __compound_select__('INTERSECT') 2126 except_ = __sub__ = __compound_select__('EXCEPT') 2127 __radd__ = __compound_select__('UNION ALL', inverted=True) 2128 __ror__ = __compound_select__('UNION', inverted=True) 2129 __rand__ = __compound_select__('INTERSECT', inverted=True) 2130 __rsub__ = __compound_select__('EXCEPT', inverted=True) 2131 2132 def select_from(self, *columns): 2133 if not columns: 2134 raise ValueError('select_from() must specify one or more columns.') 2135 2136 query = (Select((self,), columns) 2137 .bind(self._database)) 2138 if getattr(self, 'model', None) is not None: 2139 # Bind to the sub-select's model type, if defined. 2140 query = query.objects(self.model) 2141 return query 2142 2143 2144class SelectBase(_HashableSource, Source, SelectQuery): 2145 def _get_hash(self): 2146 return hash((self.__class__, self._alias or id(self))) 2147 2148 def _execute(self, database): 2149 if self._cursor_wrapper is None: 2150 cursor = database.execute(self) 2151 self._cursor_wrapper = self._get_cursor_wrapper(cursor) 2152 return self._cursor_wrapper 2153 2154 @database_required 2155 def peek(self, database, n=1): 2156 rows = self.execute(database)[:n] 2157 if rows: 2158 return rows[0] if n == 1 else rows 2159 2160 @database_required 2161 def first(self, database, n=1): 2162 if self._limit != n: 2163 self._limit = n 2164 self._cursor_wrapper = None 2165 return self.peek(database, n=n) 2166 2167 @database_required 2168 def scalar(self, database, as_tuple=False): 2169 row = self.tuples().peek(database) 2170 return row[0] if row and not as_tuple else row 2171 2172 @database_required 2173 def count(self, database, clear_limit=False): 2174 clone = self.order_by().alias('_wrapped') 2175 if clear_limit: 2176 clone._limit = clone._offset = None 2177 try: 2178 if clone._having is None and clone._group_by is None and \ 2179 clone._windows is None and clone._distinct is None and \ 2180 clone._simple_distinct is not True: 2181 clone = clone.select(SQL('1')) 2182 except AttributeError: 2183 pass 2184 return Select([clone], [fn.COUNT(SQL('1'))]).scalar(database) 2185 2186 @database_required 2187 def exists(self, database): 2188 clone = self.columns(SQL('1')) 2189 clone._limit = 1 2190 clone._offset = None 2191 return bool(clone.scalar()) 2192 2193 @database_required 2194 def get(self, database): 2195 self._cursor_wrapper = None 2196 try: 2197 return self.execute(database)[0] 2198 except IndexError: 2199 pass 2200 2201 2202# QUERY IMPLEMENTATIONS. 2203 2204 2205class CompoundSelectQuery(SelectBase): 2206 def __init__(self, lhs, op, rhs): 2207 super(CompoundSelectQuery, self).__init__() 2208 self.lhs = lhs 2209 self.op = op 2210 self.rhs = rhs 2211 2212 @property 2213 def _returning(self): 2214 return self.lhs._returning 2215 2216 @database_required 2217 def exists(self, database): 2218 query = Select((self.limit(1),), (SQL('1'),)).bind(database) 2219 return bool(query.scalar()) 2220 2221 def _get_query_key(self): 2222 return (self.lhs.get_query_key(), self.rhs.get_query_key()) 2223 2224 def _wrap_parens(self, ctx, subq): 2225 csq_setting = ctx.state.compound_select_parentheses 2226 2227 if not csq_setting or csq_setting == CSQ_PARENTHESES_NEVER: 2228 return False 2229 elif csq_setting == CSQ_PARENTHESES_ALWAYS: 2230 return True 2231 elif csq_setting == CSQ_PARENTHESES_UNNESTED: 2232 if ctx.state.in_expr or ctx.state.in_function: 2233 # If this compound select query is being used inside an 2234 # expression, e.g., an IN or EXISTS(). 2235 return False 2236 2237 # If the query on the left or right is itself a compound select 2238 # query, then we do not apply parentheses. However, if it is a 2239 # regular SELECT query, we will apply parentheses. 2240 return not isinstance(subq, CompoundSelectQuery) 2241 2242 def __sql__(self, ctx): 2243 if ctx.scope == SCOPE_COLUMN: 2244 return self.apply_column(ctx) 2245 2246 # Call parent method to handle any CTEs. 2247 super(CompoundSelectQuery, self).__sql__(ctx) 2248 2249 outer_parens = ctx.subquery or (ctx.scope == SCOPE_SOURCE) 2250 with ctx(parentheses=outer_parens): 2251 # Should the left-hand query be wrapped in parentheses? 2252 lhs_parens = self._wrap_parens(ctx, self.lhs) 2253 with ctx.scope_normal(parentheses=lhs_parens, subquery=False): 2254 ctx.sql(self.lhs) 2255 ctx.literal(' %s ' % self.op) 2256 with ctx.push_alias(): 2257 # Should the right-hand query be wrapped in parentheses? 2258 rhs_parens = self._wrap_parens(ctx, self.rhs) 2259 with ctx.scope_normal(parentheses=rhs_parens, subquery=False): 2260 ctx.sql(self.rhs) 2261 2262 # Apply ORDER BY, LIMIT, OFFSET. We use the "values" scope so that 2263 # entity names are not fully-qualified. This is a bit of a hack, as 2264 # we're relying on the logic in Column.__sql__() to not fully 2265 # qualify column names. 2266 with ctx.scope_values(): 2267 self._apply_ordering(ctx) 2268 2269 return self.apply_alias(ctx) 2270 2271 2272class Select(SelectBase): 2273 def __init__(self, from_list=None, columns=None, group_by=None, 2274 having=None, distinct=None, windows=None, for_update=None, 2275 for_update_of=None, nowait=None, lateral=None, **kwargs): 2276 super(Select, self).__init__(**kwargs) 2277 self._from_list = (list(from_list) if isinstance(from_list, tuple) 2278 else from_list) or [] 2279 self._returning = columns 2280 self._group_by = group_by 2281 self._having = having 2282 self._windows = None 2283 self._for_update = for_update # XXX: consider reorganizing. 2284 self._for_update_of = for_update_of 2285 self._for_update_nowait = nowait 2286 self._lateral = lateral 2287 2288 self._distinct = self._simple_distinct = None 2289 if distinct: 2290 if isinstance(distinct, bool): 2291 self._simple_distinct = distinct 2292 else: 2293 self._distinct = distinct 2294 2295 self._cursor_wrapper = None 2296 2297 def clone(self): 2298 clone = super(Select, self).clone() 2299 if clone._from_list: 2300 clone._from_list = list(clone._from_list) 2301 return clone 2302 2303 @Node.copy 2304 def columns(self, *columns, **kwargs): 2305 self._returning = columns 2306 select = columns 2307 2308 @Node.copy 2309 def select_extend(self, *columns): 2310 self._returning = tuple(self._returning) + columns 2311 2312 @Node.copy 2313 def from_(self, *sources): 2314 self._from_list = list(sources) 2315 2316 @Node.copy 2317 def join(self, dest, join_type=JOIN.INNER, on=None): 2318 if not self._from_list: 2319 raise ValueError('No sources to join on.') 2320 item = self._from_list.pop() 2321 self._from_list.append(Join(item, dest, join_type, on)) 2322 2323 @Node.copy 2324 def group_by(self, *columns): 2325 grouping = [] 2326 for column in columns: 2327 if isinstance(column, Table): 2328 if not column._columns: 2329 raise ValueError('Cannot pass a table to group_by() that ' 2330 'does not have columns explicitly ' 2331 'declared.') 2332 grouping.extend([getattr(column, col_name) 2333 for col_name in column._columns]) 2334 else: 2335 grouping.append(column) 2336 self._group_by = grouping 2337 2338 def group_by_extend(self, *values): 2339 """@Node.copy used from group_by() call""" 2340 group_by = tuple(self._group_by or ()) + values 2341 return self.group_by(*group_by) 2342 2343 @Node.copy 2344 def having(self, *expressions): 2345 if self._having is not None: 2346 expressions = (self._having,) + expressions 2347 self._having = reduce(operator.and_, expressions) 2348 2349 @Node.copy 2350 def distinct(self, *columns): 2351 if len(columns) == 1 and (columns[0] is True or columns[0] is False): 2352 self._simple_distinct = columns[0] 2353 else: 2354 self._simple_distinct = False 2355 self._distinct = columns 2356 2357 @Node.copy 2358 def window(self, *windows): 2359 self._windows = windows if windows else None 2360 2361 @Node.copy 2362 def for_update(self, for_update=True, of=None, nowait=None): 2363 if not for_update and (of is not None or nowait): 2364 for_update = True 2365 self._for_update = for_update 2366 self._for_update_of = of 2367 self._for_update_nowait = nowait 2368 2369 @Node.copy 2370 def lateral(self, lateral=True): 2371 self._lateral = lateral 2372 2373 def _get_query_key(self): 2374 return self._alias 2375 2376 def __sql_selection__(self, ctx, is_subquery=False): 2377 return ctx.sql(CommaNodeList(self._returning)) 2378 2379 def __sql__(self, ctx): 2380 if ctx.scope == SCOPE_COLUMN: 2381 return self.apply_column(ctx) 2382 2383 if self._lateral and ctx.scope == SCOPE_SOURCE: 2384 ctx.literal('LATERAL ') 2385 2386 is_subquery = ctx.subquery 2387 state = { 2388 'converter': None, 2389 'in_function': False, 2390 'parentheses': is_subquery or (ctx.scope == SCOPE_SOURCE), 2391 'subquery': True, 2392 } 2393 if ctx.state.in_function and ctx.state.function_arg_count == 1: 2394 state['parentheses'] = False 2395 2396 with ctx.scope_normal(**state): 2397 # Defer calling parent SQL until here. This ensures that any CTEs 2398 # for this query will be properly nested if this query is a 2399 # sub-select or is used in an expression. See GH#1809 for example. 2400 super(Select, self).__sql__(ctx) 2401 2402 ctx.literal('SELECT ') 2403 if self._simple_distinct or self._distinct is not None: 2404 ctx.literal('DISTINCT ') 2405 if self._distinct: 2406 (ctx 2407 .literal('ON ') 2408 .sql(EnclosedNodeList(self._distinct)) 2409 .literal(' ')) 2410 2411 with ctx.scope_source(): 2412 ctx = self.__sql_selection__(ctx, is_subquery) 2413 2414 if self._from_list: 2415 with ctx.scope_source(parentheses=False): 2416 ctx.literal(' FROM ').sql(CommaNodeList(self._from_list)) 2417 2418 if self._where is not None: 2419 ctx.literal(' WHERE ').sql(self._where) 2420 2421 if self._group_by: 2422 ctx.literal(' GROUP BY ').sql(CommaNodeList(self._group_by)) 2423 2424 if self._having is not None: 2425 ctx.literal(' HAVING ').sql(self._having) 2426 2427 if self._windows is not None: 2428 ctx.literal(' WINDOW ') 2429 ctx.sql(CommaNodeList(self._windows)) 2430 2431 # Apply ORDER BY, LIMIT, OFFSET. 2432 self._apply_ordering(ctx) 2433 2434 if self._for_update: 2435 if not ctx.state.for_update: 2436 raise ValueError('FOR UPDATE specified but not supported ' 2437 'by database.') 2438 ctx.literal(' ') 2439 ctx.sql(ForUpdate(self._for_update, self._for_update_of, 2440 self._for_update_nowait)) 2441 2442 # If the subquery is inside a function -or- we are evaluating a 2443 # subquery on either side of an expression w/o an explicit alias, do 2444 # not generate an alias + AS clause. 2445 if ctx.state.in_function or (ctx.state.in_expr and 2446 self._alias is None): 2447 return ctx 2448 2449 return self.apply_alias(ctx) 2450 2451 2452class _WriteQuery(Query): 2453 def __init__(self, table, returning=None, **kwargs): 2454 self.table = table 2455 self._returning = returning 2456 self._return_cursor = True if returning else False 2457 super(_WriteQuery, self).__init__(**kwargs) 2458 2459 @Node.copy 2460 def returning(self, *returning): 2461 self._returning = returning 2462 self._return_cursor = True if returning else False 2463 2464 def apply_returning(self, ctx): 2465 if self._returning: 2466 with ctx.scope_source(): 2467 ctx.literal(' RETURNING ').sql(CommaNodeList(self._returning)) 2468 return ctx 2469 2470 def _execute(self, database): 2471 if self._returning: 2472 cursor = self.execute_returning(database) 2473 else: 2474 cursor = database.execute(self) 2475 return self.handle_result(database, cursor) 2476 2477 def execute_returning(self, database): 2478 if self._cursor_wrapper is None: 2479 cursor = database.execute(self) 2480 self._cursor_wrapper = self._get_cursor_wrapper(cursor) 2481 return self._cursor_wrapper 2482 2483 def handle_result(self, database, cursor): 2484 if self._return_cursor: 2485 return cursor 2486 return database.rows_affected(cursor) 2487 2488 def _set_table_alias(self, ctx): 2489 ctx.alias_manager[self.table] = self.table.__name__ 2490 2491 def __sql__(self, ctx): 2492 super(_WriteQuery, self).__sql__(ctx) 2493 # We explicitly set the table alias to the table's name, which ensures 2494 # that if a sub-select references a column on the outer table, we won't 2495 # assign it a new alias (e.g. t2) but will refer to it as table.column. 2496 self._set_table_alias(ctx) 2497 return ctx 2498 2499 2500class Update(_WriteQuery): 2501 def __init__(self, table, update=None, **kwargs): 2502 super(Update, self).__init__(table, **kwargs) 2503 self._update = update 2504 self._from = None 2505 2506 @Node.copy 2507 def from_(self, *sources): 2508 self._from = sources 2509 2510 def __sql__(self, ctx): 2511 super(Update, self).__sql__(ctx) 2512 2513 with ctx.scope_values(subquery=True): 2514 ctx.literal('UPDATE ') 2515 2516 expressions = [] 2517 for k, v in sorted(self._update.items(), key=ctx.column_sort_key): 2518 if not isinstance(v, Node): 2519 if isinstance(k, Field): 2520 v = k.to_value(v) 2521 else: 2522 v = Value(v, unpack=False) 2523 if not isinstance(v, Value): 2524 v = qualify_names(v) 2525 expressions.append(NodeList((k, SQL('='), v))) 2526 2527 (ctx 2528 .sql(self.table) 2529 .literal(' SET ') 2530 .sql(CommaNodeList(expressions))) 2531 2532 if self._from: 2533 with ctx.scope_source(parentheses=False): 2534 ctx.literal(' FROM ').sql(CommaNodeList(self._from)) 2535 2536 if self._where: 2537 with ctx.scope_normal(): 2538 ctx.literal(' WHERE ').sql(self._where) 2539 self._apply_ordering(ctx) 2540 return self.apply_returning(ctx) 2541 2542 2543class Insert(_WriteQuery): 2544 SIMPLE = 0 2545 QUERY = 1 2546 MULTI = 2 2547 class DefaultValuesException(Exception): pass 2548 2549 def __init__(self, table, insert=None, columns=None, on_conflict=None, 2550 **kwargs): 2551 super(Insert, self).__init__(table, **kwargs) 2552 self._insert = insert 2553 self._columns = columns 2554 self._on_conflict = on_conflict 2555 self._query_type = None 2556 2557 def where(self, *expressions): 2558 raise NotImplementedError('INSERT queries cannot have a WHERE clause.') 2559 2560 @Node.copy 2561 def on_conflict_ignore(self, ignore=True): 2562 self._on_conflict = OnConflict('IGNORE') if ignore else None 2563 2564 @Node.copy 2565 def on_conflict_replace(self, replace=True): 2566 self._on_conflict = OnConflict('REPLACE') if replace else None 2567 2568 @Node.copy 2569 def on_conflict(self, *args, **kwargs): 2570 self._on_conflict = (OnConflict(*args, **kwargs) if (args or kwargs) 2571 else None) 2572 2573 def _simple_insert(self, ctx): 2574 if not self._insert: 2575 raise self.DefaultValuesException('Error: no data to insert.') 2576 return self._generate_insert((self._insert,), ctx) 2577 2578 def get_default_data(self): 2579 return {} 2580 2581 def get_default_columns(self): 2582 if self.table._columns: 2583 return [getattr(self.table, col) for col in self.table._columns 2584 if col != self.table._primary_key] 2585 2586 def _generate_insert(self, insert, ctx): 2587 rows_iter = iter(insert) 2588 columns = self._columns 2589 2590 # Load and organize column defaults (if provided). 2591 defaults = self.get_default_data() 2592 2593 # First figure out what columns are being inserted (if they weren't 2594 # specified explicitly). Resulting columns are normalized and ordered. 2595 if not columns: 2596 try: 2597 row = next(rows_iter) 2598 except StopIteration: 2599 raise self.DefaultValuesException('Error: no rows to insert.') 2600 2601 if not isinstance(row, Mapping): 2602 columns = self.get_default_columns() 2603 if columns is None: 2604 raise ValueError('Bulk insert must specify columns.') 2605 else: 2606 # Infer column names from the dict of data being inserted. 2607 accum = [] 2608 for column in row: 2609 if isinstance(column, basestring): 2610 column = getattr(self.table, column) 2611 accum.append(column) 2612 2613 # Add any columns present in the default data that are not 2614 # accounted for by the dictionary of row data. 2615 column_set = set(accum) 2616 for col in (set(defaults) - column_set): 2617 accum.append(col) 2618 2619 columns = sorted(accum, key=lambda obj: obj.get_sort_key(ctx)) 2620 rows_iter = itertools.chain(iter((row,)), rows_iter) 2621 else: 2622 clean_columns = [] 2623 seen = set() 2624 for column in columns: 2625 if isinstance(column, basestring): 2626 column_obj = getattr(self.table, column) 2627 else: 2628 column_obj = column 2629 clean_columns.append(column_obj) 2630 seen.add(column_obj) 2631 2632 columns = clean_columns 2633 for col in sorted(defaults, key=lambda obj: obj.get_sort_key(ctx)): 2634 if col not in seen: 2635 columns.append(col) 2636 2637 nullable_columns = set() 2638 value_lookups = {} 2639 for column in columns: 2640 lookups = [column, column.name] 2641 if isinstance(column, Field): 2642 if column.name != column.column_name: 2643 lookups.append(column.column_name) 2644 if column.null: 2645 nullable_columns.add(column) 2646 value_lookups[column] = lookups 2647 2648 ctx.sql(EnclosedNodeList(columns)).literal(' VALUES ') 2649 columns_converters = [ 2650 (column, column.db_value if isinstance(column, Field) else None) 2651 for column in columns] 2652 2653 all_values = [] 2654 for row in rows_iter: 2655 values = [] 2656 is_dict = isinstance(row, Mapping) 2657 for i, (column, converter) in enumerate(columns_converters): 2658 try: 2659 if is_dict: 2660 # The logic is a bit convoluted, but in order to be 2661 # flexible in what we accept (dict keyed by 2662 # column/field, field name, or underlying column name), 2663 # we try accessing the row data dict using each 2664 # possible key. If no match is found, throw an error. 2665 for lookup in value_lookups[column]: 2666 try: 2667 val = row[lookup] 2668 except KeyError: pass 2669 else: break 2670 else: 2671 raise KeyError 2672 else: 2673 val = row[i] 2674 except (KeyError, IndexError): 2675 if column in defaults: 2676 val = defaults[column] 2677 if callable_(val): 2678 val = val() 2679 elif column in nullable_columns: 2680 val = None 2681 else: 2682 raise ValueError('Missing value for %s.' % column.name) 2683 2684 if not isinstance(val, Node): 2685 val = Value(val, converter=converter, unpack=False) 2686 values.append(val) 2687 2688 all_values.append(EnclosedNodeList(values)) 2689 2690 if not all_values: 2691 raise self.DefaultValuesException('Error: no data to insert.') 2692 2693 with ctx.scope_values(subquery=True): 2694 return ctx.sql(CommaNodeList(all_values)) 2695 2696 def _query_insert(self, ctx): 2697 return (ctx 2698 .sql(EnclosedNodeList(self._columns)) 2699 .literal(' ') 2700 .sql(self._insert)) 2701 2702 def _default_values(self, ctx): 2703 if not self._database: 2704 return ctx.literal('DEFAULT VALUES') 2705 return self._database.default_values_insert(ctx) 2706 2707 def __sql__(self, ctx): 2708 super(Insert, self).__sql__(ctx) 2709 with ctx.scope_values(): 2710 stmt = None 2711 if self._on_conflict is not None: 2712 stmt = self._on_conflict.get_conflict_statement(ctx, self) 2713 2714 (ctx 2715 .sql(stmt or SQL('INSERT')) 2716 .literal(' INTO ') 2717 .sql(self.table) 2718 .literal(' ')) 2719 2720 if isinstance(self._insert, Mapping) and not self._columns: 2721 try: 2722 self._simple_insert(ctx) 2723 except self.DefaultValuesException: 2724 self._default_values(ctx) 2725 self._query_type = Insert.SIMPLE 2726 elif isinstance(self._insert, (SelectQuery, SQL)): 2727 self._query_insert(ctx) 2728 self._query_type = Insert.QUERY 2729 else: 2730 self._generate_insert(self._insert, ctx) 2731 self._query_type = Insert.MULTI 2732 2733 if self._on_conflict is not None: 2734 update = self._on_conflict.get_conflict_update(ctx, self) 2735 if update is not None: 2736 ctx.literal(' ').sql(update) 2737 2738 return self.apply_returning(ctx) 2739 2740 def _execute(self, database): 2741 if self._returning is None and database.returning_clause \ 2742 and self.table._primary_key: 2743 self._returning = (self.table._primary_key,) 2744 try: 2745 return super(Insert, self)._execute(database) 2746 except self.DefaultValuesException: 2747 pass 2748 2749 def handle_result(self, database, cursor): 2750 if self._return_cursor: 2751 return cursor 2752 if self._query_type != Insert.SIMPLE and not self._returning: 2753 return database.rows_affected(cursor) 2754 return database.last_insert_id(cursor, self._query_type) 2755 2756 2757class Delete(_WriteQuery): 2758 def __sql__(self, ctx): 2759 super(Delete, self).__sql__(ctx) 2760 2761 with ctx.scope_values(subquery=True): 2762 ctx.literal('DELETE FROM ').sql(self.table) 2763 if self._where is not None: 2764 with ctx.scope_normal(): 2765 ctx.literal(' WHERE ').sql(self._where) 2766 2767 self._apply_ordering(ctx) 2768 return self.apply_returning(ctx) 2769 2770 2771class Index(Node): 2772 def __init__(self, name, table, expressions, unique=False, safe=False, 2773 where=None, using=None): 2774 self._name = name 2775 self._table = Entity(table) if not isinstance(table, Table) else table 2776 self._expressions = expressions 2777 self._where = where 2778 self._unique = unique 2779 self._safe = safe 2780 self._using = using 2781 2782 @Node.copy 2783 def safe(self, _safe=True): 2784 self._safe = _safe 2785 2786 @Node.copy 2787 def where(self, *expressions): 2788 if self._where is not None: 2789 expressions = (self._where,) + expressions 2790 self._where = reduce(operator.and_, expressions) 2791 2792 @Node.copy 2793 def using(self, _using=None): 2794 self._using = _using 2795 2796 def __sql__(self, ctx): 2797 statement = 'CREATE UNIQUE INDEX ' if self._unique else 'CREATE INDEX ' 2798 with ctx.scope_values(subquery=True): 2799 ctx.literal(statement) 2800 if self._safe: 2801 ctx.literal('IF NOT EXISTS ') 2802 2803 # Sqlite uses CREATE INDEX <schema>.<name> ON <table>, whereas most 2804 # others use: CREATE INDEX <name> ON <schema>.<table>. 2805 if ctx.state.index_schema_prefix and \ 2806 isinstance(self._table, Table) and self._table._schema: 2807 index_name = Entity(self._table._schema, self._name) 2808 table_name = Entity(self._table.__name__) 2809 else: 2810 index_name = Entity(self._name) 2811 table_name = self._table 2812 2813 ctx.sql(index_name) 2814 if self._using is not None and \ 2815 ctx.state.index_using_precedes_table: 2816 ctx.literal(' USING %s' % self._using) # MySQL style. 2817 2818 (ctx 2819 .literal(' ON ') 2820 .sql(table_name) 2821 .literal(' ')) 2822 2823 if self._using is not None and not \ 2824 ctx.state.index_using_precedes_table: 2825 ctx.literal('USING %s ' % self._using) # Postgres/default. 2826 2827 ctx.sql(EnclosedNodeList([ 2828 SQL(expr) if isinstance(expr, basestring) else expr 2829 for expr in self._expressions])) 2830 if self._where is not None: 2831 ctx.literal(' WHERE ').sql(self._where) 2832 2833 return ctx 2834 2835 2836class ModelIndex(Index): 2837 def __init__(self, model, fields, unique=False, safe=True, where=None, 2838 using=None, name=None): 2839 self._model = model 2840 if name is None: 2841 name = self._generate_name_from_fields(model, fields) 2842 if using is None: 2843 for field in fields: 2844 if isinstance(field, Field) and hasattr(field, 'index_type'): 2845 using = field.index_type 2846 super(ModelIndex, self).__init__( 2847 name=name, 2848 table=model._meta.table, 2849 expressions=fields, 2850 unique=unique, 2851 safe=safe, 2852 where=where, 2853 using=using) 2854 2855 def _generate_name_from_fields(self, model, fields): 2856 accum = [] 2857 for field in fields: 2858 if isinstance(field, basestring): 2859 accum.append(field.split()[0]) 2860 else: 2861 if isinstance(field, Node) and not isinstance(field, Field): 2862 field = field.unwrap() 2863 if isinstance(field, Field): 2864 accum.append(field.column_name) 2865 2866 if not accum: 2867 raise ValueError('Unable to generate a name for the index, please ' 2868 'explicitly specify a name.') 2869 2870 clean_field_names = re.sub(r'[^\w]+', '', '_'.join(accum)) 2871 meta = model._meta 2872 prefix = meta.name if meta.legacy_table_names else meta.table_name 2873 return _truncate_constraint_name('_'.join((prefix, clean_field_names))) 2874 2875 2876def _truncate_constraint_name(constraint, maxlen=64): 2877 if len(constraint) > maxlen: 2878 name_hash = hashlib.md5(constraint.encode('utf-8')).hexdigest() 2879 constraint = '%s_%s' % (constraint[:(maxlen - 8)], name_hash[:7]) 2880 return constraint 2881 2882 2883# DB-API 2.0 EXCEPTIONS. 2884 2885 2886class PeeweeException(Exception): 2887 def __init__(self, *args): 2888 if args and isinstance(args[0], Exception): 2889 self.orig, args = args[0], args[1:] 2890 super(PeeweeException, self).__init__(*args) 2891class ImproperlyConfigured(PeeweeException): pass 2892class DatabaseError(PeeweeException): pass 2893class DataError(DatabaseError): pass 2894class IntegrityError(DatabaseError): pass 2895class InterfaceError(PeeweeException): pass 2896class InternalError(DatabaseError): pass 2897class NotSupportedError(DatabaseError): pass 2898class OperationalError(DatabaseError): pass 2899class ProgrammingError(DatabaseError): pass 2900 2901 2902class ExceptionWrapper(object): 2903 __slots__ = ('exceptions',) 2904 def __init__(self, exceptions): 2905 self.exceptions = exceptions 2906 def __enter__(self): pass 2907 def __exit__(self, exc_type, exc_value, traceback): 2908 if exc_type is None: 2909 return 2910 # psycopg2.8 shits out a million cute error types. Try to catch em all. 2911 if pg_errors is not None and exc_type.__name__ not in self.exceptions \ 2912 and issubclass(exc_type, pg_errors.Error): 2913 exc_type = exc_type.__bases__[0] 2914 if exc_type.__name__ in self.exceptions: 2915 new_type = self.exceptions[exc_type.__name__] 2916 exc_args = exc_value.args 2917 reraise(new_type, new_type(exc_value, *exc_args), traceback) 2918 2919 2920EXCEPTIONS = { 2921 'ConstraintError': IntegrityError, 2922 'DatabaseError': DatabaseError, 2923 'DataError': DataError, 2924 'IntegrityError': IntegrityError, 2925 'InterfaceError': InterfaceError, 2926 'InternalError': InternalError, 2927 'NotSupportedError': NotSupportedError, 2928 'OperationalError': OperationalError, 2929 'ProgrammingError': ProgrammingError, 2930 'TransactionRollbackError': OperationalError} 2931 2932__exception_wrapper__ = ExceptionWrapper(EXCEPTIONS) 2933 2934 2935# DATABASE INTERFACE AND CONNECTION MANAGEMENT. 2936 2937 2938IndexMetadata = collections.namedtuple( 2939 'IndexMetadata', 2940 ('name', 'sql', 'columns', 'unique', 'table')) 2941ColumnMetadata = collections.namedtuple( 2942 'ColumnMetadata', 2943 ('name', 'data_type', 'null', 'primary_key', 'table', 'default')) 2944ForeignKeyMetadata = collections.namedtuple( 2945 'ForeignKeyMetadata', 2946 ('column', 'dest_table', 'dest_column', 'table')) 2947ViewMetadata = collections.namedtuple('ViewMetadata', ('name', 'sql')) 2948 2949 2950class _ConnectionState(object): 2951 def __init__(self, **kwargs): 2952 super(_ConnectionState, self).__init__(**kwargs) 2953 self.reset() 2954 2955 def reset(self): 2956 self.closed = True 2957 self.conn = None 2958 self.ctx = [] 2959 self.transactions = [] 2960 2961 def set_connection(self, conn): 2962 self.conn = conn 2963 self.closed = False 2964 self.ctx = [] 2965 self.transactions = [] 2966 2967 2968class _ConnectionLocal(_ConnectionState, threading.local): pass 2969class _NoopLock(object): 2970 __slots__ = () 2971 def __enter__(self): return self 2972 def __exit__(self, exc_type, exc_val, exc_tb): pass 2973 2974 2975class ConnectionContext(_callable_context_manager): 2976 __slots__ = ('db',) 2977 def __init__(self, db): self.db = db 2978 def __enter__(self): 2979 if self.db.is_closed(): 2980 self.db.connect() 2981 def __exit__(self, exc_type, exc_val, exc_tb): self.db.close() 2982 2983 2984class Database(_callable_context_manager): 2985 context_class = Context 2986 field_types = {} 2987 operations = {} 2988 param = '?' 2989 quote = '""' 2990 server_version = None 2991 2992 # Feature toggles. 2993 commit_select = False 2994 compound_select_parentheses = CSQ_PARENTHESES_NEVER 2995 for_update = False 2996 index_schema_prefix = False 2997 index_using_precedes_table = False 2998 limit_max = None 2999 nulls_ordering = False 3000 returning_clause = False 3001 safe_create_index = True 3002 safe_drop_index = True 3003 sequences = False 3004 truncate_table = True 3005 3006 def __init__(self, database, thread_safe=True, autorollback=False, 3007 field_types=None, operations=None, autocommit=None, 3008 autoconnect=True, **kwargs): 3009 self._field_types = merge_dict(FIELD, self.field_types) 3010 self._operations = merge_dict(OP, self.operations) 3011 if field_types: 3012 self._field_types.update(field_types) 3013 if operations: 3014 self._operations.update(operations) 3015 3016 self.autoconnect = autoconnect 3017 self.autorollback = autorollback 3018 self.thread_safe = thread_safe 3019 if thread_safe: 3020 self._state = _ConnectionLocal() 3021 self._lock = threading.RLock() 3022 else: 3023 self._state = _ConnectionState() 3024 self._lock = _NoopLock() 3025 3026 if autocommit is not None: 3027 __deprecated__('Peewee no longer uses the "autocommit" option, as ' 3028 'the semantics now require it to always be True. ' 3029 'Because some database-drivers also use the ' 3030 '"autocommit" parameter, you are receiving a ' 3031 'warning so you may update your code and remove ' 3032 'the parameter, as in the future, specifying ' 3033 'autocommit could impact the behavior of the ' 3034 'database driver you are using.') 3035 3036 self.connect_params = {} 3037 self.init(database, **kwargs) 3038 3039 def init(self, database, **kwargs): 3040 if not self.is_closed(): 3041 self.close() 3042 self.database = database 3043 self.connect_params.update(kwargs) 3044 self.deferred = not bool(database) 3045 3046 def __enter__(self): 3047 if self.is_closed(): 3048 self.connect() 3049 ctx = self.atomic() 3050 self._state.ctx.append(ctx) 3051 ctx.__enter__() 3052 return self 3053 3054 def __exit__(self, exc_type, exc_val, exc_tb): 3055 ctx = self._state.ctx.pop() 3056 try: 3057 ctx.__exit__(exc_type, exc_val, exc_tb) 3058 finally: 3059 if not self._state.ctx: 3060 self.close() 3061 3062 def connection_context(self): 3063 return ConnectionContext(self) 3064 3065 def _connect(self): 3066 raise NotImplementedError 3067 3068 def connect(self, reuse_if_open=False): 3069 with self._lock: 3070 if self.deferred: 3071 raise InterfaceError('Error, database must be initialized ' 3072 'before opening a connection.') 3073 if not self._state.closed: 3074 if reuse_if_open: 3075 return False 3076 raise OperationalError('Connection already opened.') 3077 3078 self._state.reset() 3079 with __exception_wrapper__: 3080 self._state.set_connection(self._connect()) 3081 if self.server_version is None: 3082 self._set_server_version(self._state.conn) 3083 self._initialize_connection(self._state.conn) 3084 return True 3085 3086 def _initialize_connection(self, conn): 3087 pass 3088 3089 def _set_server_version(self, conn): 3090 self.server_version = 0 3091 3092 def close(self): 3093 with self._lock: 3094 if self.deferred: 3095 raise InterfaceError('Error, database must be initialized ' 3096 'before opening a connection.') 3097 if self.in_transaction(): 3098 raise OperationalError('Attempting to close database while ' 3099 'transaction is open.') 3100 is_open = not self._state.closed 3101 try: 3102 if is_open: 3103 with __exception_wrapper__: 3104 self._close(self._state.conn) 3105 finally: 3106 self._state.reset() 3107 return is_open 3108 3109 def _close(self, conn): 3110 conn.close() 3111 3112 def is_closed(self): 3113 return self._state.closed 3114 3115 def is_connection_usable(self): 3116 return not self._state.closed 3117 3118 def connection(self): 3119 if self.is_closed(): 3120 self.connect() 3121 return self._state.conn 3122 3123 def cursor(self, commit=None): 3124 if self.is_closed(): 3125 if self.autoconnect: 3126 self.connect() 3127 else: 3128 raise InterfaceError('Error, database connection not opened.') 3129 return self._state.conn.cursor() 3130 3131 def execute_sql(self, sql, params=None, commit=SENTINEL): 3132 logger.debug((sql, params)) 3133 if commit is SENTINEL: 3134 if self.in_transaction(): 3135 commit = False 3136 elif self.commit_select: 3137 commit = True 3138 else: 3139 commit = not sql[:6].lower().startswith('select') 3140 3141 with __exception_wrapper__: 3142 cursor = self.cursor(commit) 3143 try: 3144 cursor.execute(sql, params or ()) 3145 except Exception: 3146 if self.autorollback and not self.in_transaction(): 3147 self.rollback() 3148 raise 3149 else: 3150 if commit and not self.in_transaction(): 3151 self.commit() 3152 return cursor 3153 3154 def execute(self, query, commit=SENTINEL, **context_options): 3155 ctx = self.get_sql_context(**context_options) 3156 sql, params = ctx.sql(query).query() 3157 return self.execute_sql(sql, params, commit=commit) 3158 3159 def get_context_options(self): 3160 return { 3161 'field_types': self._field_types, 3162 'operations': self._operations, 3163 'param': self.param, 3164 'quote': self.quote, 3165 'compound_select_parentheses': self.compound_select_parentheses, 3166 'conflict_statement': self.conflict_statement, 3167 'conflict_update': self.conflict_update, 3168 'for_update': self.for_update, 3169 'index_schema_prefix': self.index_schema_prefix, 3170 'index_using_precedes_table': self.index_using_precedes_table, 3171 'limit_max': self.limit_max, 3172 'nulls_ordering': self.nulls_ordering, 3173 } 3174 3175 def get_sql_context(self, **context_options): 3176 context = self.get_context_options() 3177 if context_options: 3178 context.update(context_options) 3179 return self.context_class(**context) 3180 3181 def conflict_statement(self, on_conflict, query): 3182 raise NotImplementedError 3183 3184 def conflict_update(self, on_conflict, query): 3185 raise NotImplementedError 3186 3187 def _build_on_conflict_update(self, on_conflict, query): 3188 if on_conflict._conflict_target: 3189 stmt = SQL('ON CONFLICT') 3190 target = EnclosedNodeList([ 3191 Entity(col) if isinstance(col, basestring) else col 3192 for col in on_conflict._conflict_target]) 3193 if on_conflict._conflict_where is not None: 3194 target = NodeList([target, SQL('WHERE'), 3195 on_conflict._conflict_where]) 3196 else: 3197 stmt = SQL('ON CONFLICT ON CONSTRAINT') 3198 target = on_conflict._conflict_constraint 3199 if isinstance(target, basestring): 3200 target = Entity(target) 3201 3202 updates = [] 3203 if on_conflict._preserve: 3204 for column in on_conflict._preserve: 3205 excluded = NodeList((SQL('EXCLUDED'), ensure_entity(column)), 3206 glue='.') 3207 expression = NodeList((ensure_entity(column), SQL('='), 3208 excluded)) 3209 updates.append(expression) 3210 3211 if on_conflict._update: 3212 for k, v in on_conflict._update.items(): 3213 if not isinstance(v, Node): 3214 # Attempt to resolve string field-names to their respective 3215 # field object, to apply data-type conversions. 3216 if isinstance(k, basestring): 3217 k = getattr(query.table, k) 3218 if isinstance(k, Field): 3219 v = k.to_value(v) 3220 else: 3221 v = Value(v, unpack=False) 3222 else: 3223 v = QualifiedNames(v) 3224 updates.append(NodeList((ensure_entity(k), SQL('='), v))) 3225 3226 parts = [stmt, target, SQL('DO UPDATE SET'), CommaNodeList(updates)] 3227 if on_conflict._where: 3228 parts.extend((SQL('WHERE'), QualifiedNames(on_conflict._where))) 3229 3230 return NodeList(parts) 3231 3232 def last_insert_id(self, cursor, query_type=None): 3233 return cursor.lastrowid 3234 3235 def rows_affected(self, cursor): 3236 return cursor.rowcount 3237 3238 def default_values_insert(self, ctx): 3239 return ctx.literal('DEFAULT VALUES') 3240 3241 def session_start(self): 3242 with self._lock: 3243 return self.transaction().__enter__() 3244 3245 def session_commit(self): 3246 with self._lock: 3247 try: 3248 txn = self.pop_transaction() 3249 except IndexError: 3250 return False 3251 txn.commit(begin=self.in_transaction()) 3252 return True 3253 3254 def session_rollback(self): 3255 with self._lock: 3256 try: 3257 txn = self.pop_transaction() 3258 except IndexError: 3259 return False 3260 txn.rollback(begin=self.in_transaction()) 3261 return True 3262 3263 def in_transaction(self): 3264 return bool(self._state.transactions) 3265 3266 def push_transaction(self, transaction): 3267 self._state.transactions.append(transaction) 3268 3269 def pop_transaction(self): 3270 return self._state.transactions.pop() 3271 3272 def transaction_depth(self): 3273 return len(self._state.transactions) 3274 3275 def top_transaction(self): 3276 if self._state.transactions: 3277 return self._state.transactions[-1] 3278 3279 def atomic(self, *args, **kwargs): 3280 return _atomic(self, *args, **kwargs) 3281 3282 def manual_commit(self): 3283 return _manual(self) 3284 3285 def transaction(self, *args, **kwargs): 3286 return _transaction(self, *args, **kwargs) 3287 3288 def savepoint(self): 3289 return _savepoint(self) 3290 3291 def begin(self): 3292 if self.is_closed(): 3293 self.connect() 3294 3295 def commit(self): 3296 with __exception_wrapper__: 3297 return self._state.conn.commit() 3298 3299 def rollback(self): 3300 with __exception_wrapper__: 3301 return self._state.conn.rollback() 3302 3303 def batch_commit(self, it, n): 3304 for group in chunked(it, n): 3305 with self.atomic(): 3306 for obj in group: 3307 yield obj 3308 3309 def table_exists(self, table_name, schema=None): 3310 return table_name in self.get_tables(schema=schema) 3311 3312 def get_tables(self, schema=None): 3313 raise NotImplementedError 3314 3315 def get_indexes(self, table, schema=None): 3316 raise NotImplementedError 3317 3318 def get_columns(self, table, schema=None): 3319 raise NotImplementedError 3320 3321 def get_primary_keys(self, table, schema=None): 3322 raise NotImplementedError 3323 3324 def get_foreign_keys(self, table, schema=None): 3325 raise NotImplementedError 3326 3327 def sequence_exists(self, seq): 3328 raise NotImplementedError 3329 3330 def create_tables(self, models, **options): 3331 for model in sort_models(models): 3332 model.create_table(**options) 3333 3334 def drop_tables(self, models, **kwargs): 3335 for model in reversed(sort_models(models)): 3336 model.drop_table(**kwargs) 3337 3338 def extract_date(self, date_part, date_field): 3339 raise NotImplementedError 3340 3341 def truncate_date(self, date_part, date_field): 3342 raise NotImplementedError 3343 3344 def to_timestamp(self, date_field): 3345 raise NotImplementedError 3346 3347 def from_timestamp(self, date_field): 3348 raise NotImplementedError 3349 3350 def random(self): 3351 return fn.random() 3352 3353 def bind(self, models, bind_refs=True, bind_backrefs=True): 3354 for model in models: 3355 model.bind(self, bind_refs=bind_refs, bind_backrefs=bind_backrefs) 3356 3357 def bind_ctx(self, models, bind_refs=True, bind_backrefs=True): 3358 return _BoundModelsContext(models, self, bind_refs, bind_backrefs) 3359 3360 def get_noop_select(self, ctx): 3361 return ctx.sql(Select().columns(SQL('0')).where(SQL('0'))) 3362 3363 3364def __pragma__(name): 3365 def __get__(self): 3366 return self.pragma(name) 3367 def __set__(self, value): 3368 return self.pragma(name, value) 3369 return property(__get__, __set__) 3370 3371 3372class SqliteDatabase(Database): 3373 field_types = { 3374 'BIGAUTO': FIELD.AUTO, 3375 'BIGINT': FIELD.INT, 3376 'BOOL': FIELD.INT, 3377 'DOUBLE': FIELD.FLOAT, 3378 'SMALLINT': FIELD.INT, 3379 'UUID': FIELD.TEXT} 3380 operations = { 3381 'LIKE': 'GLOB', 3382 'ILIKE': 'LIKE'} 3383 index_schema_prefix = True 3384 limit_max = -1 3385 server_version = __sqlite_version__ 3386 truncate_table = False 3387 3388 def __init__(self, database, *args, **kwargs): 3389 self._pragmas = kwargs.pop('pragmas', ()) 3390 super(SqliteDatabase, self).__init__(database, *args, **kwargs) 3391 self._aggregates = {} 3392 self._collations = {} 3393 self._functions = {} 3394 self._window_functions = {} 3395 self._table_functions = [] 3396 self._extensions = set() 3397 self._attached = {} 3398 self.register_function(_sqlite_date_part, 'date_part', 2) 3399 self.register_function(_sqlite_date_trunc, 'date_trunc', 2) 3400 self.nulls_ordering = self.server_version >= (3, 30, 0) 3401 3402 def init(self, database, pragmas=None, timeout=5, **kwargs): 3403 if pragmas is not None: 3404 self._pragmas = pragmas 3405 if isinstance(self._pragmas, dict): 3406 self._pragmas = list(self._pragmas.items()) 3407 self._timeout = timeout 3408 super(SqliteDatabase, self).init(database, **kwargs) 3409 3410 def _set_server_version(self, conn): 3411 pass 3412 3413 def _connect(self): 3414 if sqlite3 is None: 3415 raise ImproperlyConfigured('SQLite driver not installed!') 3416 conn = sqlite3.connect(self.database, timeout=self._timeout, 3417 isolation_level=None, **self.connect_params) 3418 try: 3419 self._add_conn_hooks(conn) 3420 except: 3421 conn.close() 3422 raise 3423 return conn 3424 3425 def _add_conn_hooks(self, conn): 3426 if self._attached: 3427 self._attach_databases(conn) 3428 if self._pragmas: 3429 self._set_pragmas(conn) 3430 self._load_aggregates(conn) 3431 self._load_collations(conn) 3432 self._load_functions(conn) 3433 if self.server_version >= (3, 25, 0): 3434 self._load_window_functions(conn) 3435 if self._table_functions: 3436 for table_function in self._table_functions: 3437 table_function.register(conn) 3438 if self._extensions: 3439 self._load_extensions(conn) 3440 3441 def _set_pragmas(self, conn): 3442 cursor = conn.cursor() 3443 for pragma, value in self._pragmas: 3444 cursor.execute('PRAGMA %s = %s;' % (pragma, value)) 3445 cursor.close() 3446 3447 def _attach_databases(self, conn): 3448 cursor = conn.cursor() 3449 for name, db in self._attached.items(): 3450 cursor.execute('ATTACH DATABASE "%s" AS "%s"' % (db, name)) 3451 cursor.close() 3452 3453 def pragma(self, key, value=SENTINEL, permanent=False, schema=None): 3454 if schema is not None: 3455 key = '"%s".%s' % (schema, key) 3456 sql = 'PRAGMA %s' % key 3457 if value is not SENTINEL: 3458 sql += ' = %s' % (value or 0) 3459 if permanent: 3460 pragmas = dict(self._pragmas or ()) 3461 pragmas[key] = value 3462 self._pragmas = list(pragmas.items()) 3463 elif permanent: 3464 raise ValueError('Cannot specify a permanent pragma without value') 3465 row = self.execute_sql(sql).fetchone() 3466 if row: 3467 return row[0] 3468 3469 cache_size = __pragma__('cache_size') 3470 foreign_keys = __pragma__('foreign_keys') 3471 journal_mode = __pragma__('journal_mode') 3472 journal_size_limit = __pragma__('journal_size_limit') 3473 mmap_size = __pragma__('mmap_size') 3474 page_size = __pragma__('page_size') 3475 read_uncommitted = __pragma__('read_uncommitted') 3476 synchronous = __pragma__('synchronous') 3477 wal_autocheckpoint = __pragma__('wal_autocheckpoint') 3478 3479 @property 3480 def timeout(self): 3481 return self._timeout 3482 3483 @timeout.setter 3484 def timeout(self, seconds): 3485 if self._timeout == seconds: 3486 return 3487 3488 self._timeout = seconds 3489 if not self.is_closed(): 3490 # PySQLite multiplies user timeout by 1000, but the unit of the 3491 # timeout PRAGMA is actually milliseconds. 3492 self.execute_sql('PRAGMA busy_timeout=%d;' % (seconds * 1000)) 3493 3494 def _load_aggregates(self, conn): 3495 for name, (klass, num_params) in self._aggregates.items(): 3496 conn.create_aggregate(name, num_params, klass) 3497 3498 def _load_collations(self, conn): 3499 for name, fn in self._collations.items(): 3500 conn.create_collation(name, fn) 3501 3502 def _load_functions(self, conn): 3503 for name, (fn, num_params) in self._functions.items(): 3504 conn.create_function(name, num_params, fn) 3505 3506 def _load_window_functions(self, conn): 3507 for name, (klass, num_params) in self._window_functions.items(): 3508 conn.create_window_function(name, num_params, klass) 3509 3510 def register_aggregate(self, klass, name=None, num_params=-1): 3511 self._aggregates[name or klass.__name__.lower()] = (klass, num_params) 3512 if not self.is_closed(): 3513 self._load_aggregates(self.connection()) 3514 3515 def aggregate(self, name=None, num_params=-1): 3516 def decorator(klass): 3517 self.register_aggregate(klass, name, num_params) 3518 return klass 3519 return decorator 3520 3521 def register_collation(self, fn, name=None): 3522 name = name or fn.__name__ 3523 def _collation(*args): 3524 expressions = args + (SQL('collate %s' % name),) 3525 return NodeList(expressions) 3526 fn.collation = _collation 3527 self._collations[name] = fn 3528 if not self.is_closed(): 3529 self._load_collations(self.connection()) 3530 3531 def collation(self, name=None): 3532 def decorator(fn): 3533 self.register_collation(fn, name) 3534 return fn 3535 return decorator 3536 3537 def register_function(self, fn, name=None, num_params=-1): 3538 self._functions[name or fn.__name__] = (fn, num_params) 3539 if not self.is_closed(): 3540 self._load_functions(self.connection()) 3541 3542 def func(self, name=None, num_params=-1): 3543 def decorator(fn): 3544 self.register_function(fn, name, num_params) 3545 return fn 3546 return decorator 3547 3548 def register_window_function(self, klass, name=None, num_params=-1): 3549 name = name or klass.__name__.lower() 3550 self._window_functions[name] = (klass, num_params) 3551 if not self.is_closed(): 3552 self._load_window_functions(self.connection()) 3553 3554 def window_function(self, name=None, num_params=-1): 3555 def decorator(klass): 3556 self.register_window_function(klass, name, num_params) 3557 return klass 3558 return decorator 3559 3560 def register_table_function(self, klass, name=None): 3561 if name is not None: 3562 klass.name = name 3563 self._table_functions.append(klass) 3564 if not self.is_closed(): 3565 klass.register(self.connection()) 3566 3567 def table_function(self, name=None): 3568 def decorator(klass): 3569 self.register_table_function(klass, name) 3570 return klass 3571 return decorator 3572 3573 def unregister_aggregate(self, name): 3574 del(self._aggregates[name]) 3575 3576 def unregister_collation(self, name): 3577 del(self._collations[name]) 3578 3579 def unregister_function(self, name): 3580 del(self._functions[name]) 3581 3582 def unregister_window_function(self, name): 3583 del(self._window_functions[name]) 3584 3585 def unregister_table_function(self, name): 3586 for idx, klass in enumerate(self._table_functions): 3587 if klass.name == name: 3588 break 3589 else: 3590 return False 3591 self._table_functions.pop(idx) 3592 return True 3593 3594 def _load_extensions(self, conn): 3595 conn.enable_load_extension(True) 3596 for extension in self._extensions: 3597 conn.load_extension(extension) 3598 3599 def load_extension(self, extension): 3600 self._extensions.add(extension) 3601 if not self.is_closed(): 3602 conn = self.connection() 3603 conn.enable_load_extension(True) 3604 conn.load_extension(extension) 3605 3606 def unload_extension(self, extension): 3607 self._extensions.remove(extension) 3608 3609 def attach(self, filename, name): 3610 if name in self._attached: 3611 if self._attached[name] == filename: 3612 return False 3613 raise OperationalError('schema "%s" already attached.' % name) 3614 3615 self._attached[name] = filename 3616 if not self.is_closed(): 3617 self.execute_sql('ATTACH DATABASE "%s" AS "%s"' % (filename, name)) 3618 return True 3619 3620 def detach(self, name): 3621 if name not in self._attached: 3622 return False 3623 3624 del self._attached[name] 3625 if not self.is_closed(): 3626 self.execute_sql('DETACH DATABASE "%s"' % name) 3627 return True 3628 3629 def begin(self, lock_type=None): 3630 statement = 'BEGIN %s' % lock_type if lock_type else 'BEGIN' 3631 self.execute_sql(statement, commit=False) 3632 3633 def get_tables(self, schema=None): 3634 schema = schema or 'main' 3635 cursor = self.execute_sql('SELECT name FROM "%s".sqlite_master WHERE ' 3636 'type=? ORDER BY name' % schema, ('table',)) 3637 return [row for row, in cursor.fetchall()] 3638 3639 def get_views(self, schema=None): 3640 sql = ('SELECT name, sql FROM "%s".sqlite_master WHERE type=? ' 3641 'ORDER BY name') % (schema or 'main') 3642 return [ViewMetadata(*row) for row in self.execute_sql(sql, ('view',))] 3643 3644 def get_indexes(self, table, schema=None): 3645 schema = schema or 'main' 3646 query = ('SELECT name, sql FROM "%s".sqlite_master ' 3647 'WHERE tbl_name = ? AND type = ? ORDER BY name') % schema 3648 cursor = self.execute_sql(query, (table, 'index')) 3649 index_to_sql = dict(cursor.fetchall()) 3650 3651 # Determine which indexes have a unique constraint. 3652 unique_indexes = set() 3653 cursor = self.execute_sql('PRAGMA "%s".index_list("%s")' % 3654 (schema, table)) 3655 for row in cursor.fetchall(): 3656 name = row[1] 3657 is_unique = int(row[2]) == 1 3658 if is_unique: 3659 unique_indexes.add(name) 3660 3661 # Retrieve the indexed columns. 3662 index_columns = {} 3663 for index_name in sorted(index_to_sql): 3664 cursor = self.execute_sql('PRAGMA "%s".index_info("%s")' % 3665 (schema, index_name)) 3666 index_columns[index_name] = [row[2] for row in cursor.fetchall()] 3667 3668 return [ 3669 IndexMetadata( 3670 name, 3671 index_to_sql[name], 3672 index_columns[name], 3673 name in unique_indexes, 3674 table) 3675 for name in sorted(index_to_sql)] 3676 3677 def get_columns(self, table, schema=None): 3678 cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' % 3679 (schema or 'main', table)) 3680 return [ColumnMetadata(r[1], r[2], not r[3], bool(r[5]), table, r[4]) 3681 for r in cursor.fetchall()] 3682 3683 def get_primary_keys(self, table, schema=None): 3684 cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' % 3685 (schema or 'main', table)) 3686 return [row[1] for row in filter(lambda r: r[-1], cursor.fetchall())] 3687 3688 def get_foreign_keys(self, table, schema=None): 3689 cursor = self.execute_sql('PRAGMA "%s".foreign_key_list("%s")' % 3690 (schema or 'main', table)) 3691 return [ForeignKeyMetadata(row[3], row[2], row[4], table) 3692 for row in cursor.fetchall()] 3693 3694 def get_binary_type(self): 3695 return sqlite3.Binary 3696 3697 def conflict_statement(self, on_conflict, query): 3698 action = on_conflict._action.lower() if on_conflict._action else '' 3699 if action and action not in ('nothing', 'update'): 3700 return SQL('INSERT OR %s' % on_conflict._action.upper()) 3701 3702 def conflict_update(self, oc, query): 3703 # Sqlite prior to 3.24.0 does not support Postgres-style upsert. 3704 if self.server_version < (3, 24, 0) and \ 3705 any((oc._preserve, oc._update, oc._where, oc._conflict_target, 3706 oc._conflict_constraint)): 3707 raise ValueError('SQLite does not support specifying which values ' 3708 'to preserve or update.') 3709 3710 action = oc._action.lower() if oc._action else '' 3711 if action and action not in ('nothing', 'update', ''): 3712 return 3713 3714 if action == 'nothing': 3715 return SQL('ON CONFLICT DO NOTHING') 3716 elif not oc._update and not oc._preserve: 3717 raise ValueError('If you are not performing any updates (or ' 3718 'preserving any INSERTed values), then the ' 3719 'conflict resolution action should be set to ' 3720 '"NOTHING".') 3721 elif oc._conflict_constraint: 3722 raise ValueError('SQLite does not support specifying named ' 3723 'constraints for conflict resolution.') 3724 elif not oc._conflict_target: 3725 raise ValueError('SQLite requires that a conflict target be ' 3726 'specified when doing an upsert.') 3727 3728 return self._build_on_conflict_update(oc, query) 3729 3730 def extract_date(self, date_part, date_field): 3731 return fn.date_part(date_part, date_field, python_value=int) 3732 3733 def truncate_date(self, date_part, date_field): 3734 return fn.date_trunc(date_part, date_field, 3735 python_value=simple_date_time) 3736 3737 def to_timestamp(self, date_field): 3738 return fn.strftime('%s', date_field).cast('integer') 3739 3740 def from_timestamp(self, date_field): 3741 return fn.datetime(date_field, 'unixepoch') 3742 3743 3744class PostgresqlDatabase(Database): 3745 field_types = { 3746 'AUTO': 'SERIAL', 3747 'BIGAUTO': 'BIGSERIAL', 3748 'BLOB': 'BYTEA', 3749 'BOOL': 'BOOLEAN', 3750 'DATETIME': 'TIMESTAMP', 3751 'DECIMAL': 'NUMERIC', 3752 'DOUBLE': 'DOUBLE PRECISION', 3753 'UUID': 'UUID', 3754 'UUIDB': 'BYTEA'} 3755 operations = {'REGEXP': '~', 'IREGEXP': '~*'} 3756 param = '%s' 3757 3758 commit_select = True 3759 compound_select_parentheses = CSQ_PARENTHESES_ALWAYS 3760 for_update = True 3761 nulls_ordering = True 3762 returning_clause = True 3763 safe_create_index = False 3764 sequences = True 3765 3766 def init(self, database, register_unicode=True, encoding=None, 3767 isolation_level=None, **kwargs): 3768 self._register_unicode = register_unicode 3769 self._encoding = encoding 3770 self._isolation_level = isolation_level 3771 super(PostgresqlDatabase, self).init(database, **kwargs) 3772 3773 def _connect(self): 3774 if psycopg2 is None: 3775 raise ImproperlyConfigured('Postgres driver not installed!') 3776 conn = psycopg2.connect(database=self.database, **self.connect_params) 3777 if self._register_unicode: 3778 pg_extensions.register_type(pg_extensions.UNICODE, conn) 3779 pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn) 3780 if self._encoding: 3781 conn.set_client_encoding(self._encoding) 3782 if self._isolation_level: 3783 conn.set_isolation_level(self._isolation_level) 3784 return conn 3785 3786 def _set_server_version(self, conn): 3787 self.server_version = conn.server_version 3788 if self.server_version >= 90600: 3789 self.safe_create_index = True 3790 3791 def is_connection_usable(self): 3792 if self._state.closed: 3793 return False 3794 3795 # Returns True if we are idle, running a command, or in an active 3796 # connection. If the connection is in an error state or the connection 3797 # is otherwise unusable, return False. 3798 txn_status = self._state.conn.get_transaction_status() 3799 return txn_status < pg_extensions.TRANSACTION_STATUS_INERROR 3800 3801 def last_insert_id(self, cursor, query_type=None): 3802 try: 3803 return cursor if query_type != Insert.SIMPLE else cursor[0][0] 3804 except (IndexError, KeyError, TypeError): 3805 pass 3806 3807 def get_tables(self, schema=None): 3808 query = ('SELECT tablename FROM pg_catalog.pg_tables ' 3809 'WHERE schemaname = %s ORDER BY tablename') 3810 cursor = self.execute_sql(query, (schema or 'public',)) 3811 return [table for table, in cursor.fetchall()] 3812 3813 def get_views(self, schema=None): 3814 query = ('SELECT viewname, definition FROM pg_catalog.pg_views ' 3815 'WHERE schemaname = %s ORDER BY viewname') 3816 cursor = self.execute_sql(query, (schema or 'public',)) 3817 return [ViewMetadata(view_name, sql.strip(' \t;')) 3818 for (view_name, sql) in cursor.fetchall()] 3819 3820 def get_indexes(self, table, schema=None): 3821 query = """ 3822 SELECT 3823 i.relname, idxs.indexdef, idx.indisunique, 3824 array_to_string(ARRAY( 3825 SELECT pg_get_indexdef(idx.indexrelid, k + 1, TRUE) 3826 FROM generate_subscripts(idx.indkey, 1) AS k 3827 ORDER BY k), ',') 3828 FROM pg_catalog.pg_class AS t 3829 INNER JOIN pg_catalog.pg_index AS idx ON t.oid = idx.indrelid 3830 INNER JOIN pg_catalog.pg_class AS i ON idx.indexrelid = i.oid 3831 INNER JOIN pg_catalog.pg_indexes AS idxs ON 3832 (idxs.tablename = t.relname AND idxs.indexname = i.relname) 3833 WHERE t.relname = %s AND t.relkind = %s AND idxs.schemaname = %s 3834 ORDER BY idx.indisunique DESC, i.relname;""" 3835 cursor = self.execute_sql(query, (table, 'r', schema or 'public')) 3836 return [IndexMetadata(name, sql.rstrip(' ;'), columns.split(','), 3837 is_unique, table) 3838 for name, sql, is_unique, columns in cursor.fetchall()] 3839 3840 def get_columns(self, table, schema=None): 3841 query = """ 3842 SELECT column_name, is_nullable, data_type, column_default 3843 FROM information_schema.columns 3844 WHERE table_name = %s AND table_schema = %s 3845 ORDER BY ordinal_position""" 3846 cursor = self.execute_sql(query, (table, schema or 'public')) 3847 pks = set(self.get_primary_keys(table, schema)) 3848 return [ColumnMetadata(name, dt, null == 'YES', name in pks, table, df) 3849 for name, null, dt, df in cursor.fetchall()] 3850 3851 def get_primary_keys(self, table, schema=None): 3852 query = """ 3853 SELECT kc.column_name 3854 FROM information_schema.table_constraints AS tc 3855 INNER JOIN information_schema.key_column_usage AS kc ON ( 3856 tc.table_name = kc.table_name AND 3857 tc.table_schema = kc.table_schema AND 3858 tc.constraint_name = kc.constraint_name) 3859 WHERE 3860 tc.constraint_type = %s AND 3861 tc.table_name = %s AND 3862 tc.table_schema = %s""" 3863 ctype = 'PRIMARY KEY' 3864 cursor = self.execute_sql(query, (ctype, table, schema or 'public')) 3865 return [pk for pk, in cursor.fetchall()] 3866 3867 def get_foreign_keys(self, table, schema=None): 3868 sql = """ 3869 SELECT DISTINCT 3870 kcu.column_name, ccu.table_name, ccu.column_name 3871 FROM information_schema.table_constraints AS tc 3872 JOIN information_schema.key_column_usage AS kcu 3873 ON (tc.constraint_name = kcu.constraint_name AND 3874 tc.constraint_schema = kcu.constraint_schema AND 3875 tc.table_name = kcu.table_name AND 3876 tc.table_schema = kcu.table_schema) 3877 JOIN information_schema.constraint_column_usage AS ccu 3878 ON (ccu.constraint_name = tc.constraint_name AND 3879 ccu.constraint_schema = tc.constraint_schema) 3880 WHERE 3881 tc.constraint_type = 'FOREIGN KEY' AND 3882 tc.table_name = %s AND 3883 tc.table_schema = %s""" 3884 cursor = self.execute_sql(sql, (table, schema or 'public')) 3885 return [ForeignKeyMetadata(row[0], row[1], row[2], table) 3886 for row in cursor.fetchall()] 3887 3888 def sequence_exists(self, sequence): 3889 res = self.execute_sql(""" 3890 SELECT COUNT(*) FROM pg_class, pg_namespace 3891 WHERE relkind='S' 3892 AND pg_class.relnamespace = pg_namespace.oid 3893 AND relname=%s""", (sequence,)) 3894 return bool(res.fetchone()[0]) 3895 3896 def get_binary_type(self): 3897 return psycopg2.Binary 3898 3899 def conflict_statement(self, on_conflict, query): 3900 return 3901 3902 def conflict_update(self, oc, query): 3903 action = oc._action.lower() if oc._action else '' 3904 if action in ('ignore', 'nothing'): 3905 return SQL('ON CONFLICT DO NOTHING') 3906 elif action and action != 'update': 3907 raise ValueError('The only supported actions for conflict ' 3908 'resolution with Postgresql are "ignore" or ' 3909 '"update".') 3910 elif not oc._update and not oc._preserve: 3911 raise ValueError('If you are not performing any updates (or ' 3912 'preserving any INSERTed values), then the ' 3913 'conflict resolution action should be set to ' 3914 '"IGNORE".') 3915 elif not (oc._conflict_target or oc._conflict_constraint): 3916 raise ValueError('Postgres requires that a conflict target be ' 3917 'specified when doing an upsert.') 3918 3919 return self._build_on_conflict_update(oc, query) 3920 3921 def extract_date(self, date_part, date_field): 3922 return fn.EXTRACT(NodeList((date_part, SQL('FROM'), date_field))) 3923 3924 def truncate_date(self, date_part, date_field): 3925 return fn.DATE_TRUNC(date_part, date_field) 3926 3927 def to_timestamp(self, date_field): 3928 return self.extract_date('EPOCH', date_field) 3929 3930 def from_timestamp(self, date_field): 3931 # Ironically, here, Postgres means "to the Postgresql timestamp type". 3932 return fn.to_timestamp(date_field) 3933 3934 def get_noop_select(self, ctx): 3935 return ctx.sql(Select().columns(SQL('0')).where(SQL('false'))) 3936 3937 def set_time_zone(self, timezone): 3938 self.execute_sql('set time zone "%s";' % timezone) 3939 3940 3941class MySQLDatabase(Database): 3942 field_types = { 3943 'AUTO': 'INTEGER AUTO_INCREMENT', 3944 'BIGAUTO': 'BIGINT AUTO_INCREMENT', 3945 'BOOL': 'BOOL', 3946 'DECIMAL': 'NUMERIC', 3947 'DOUBLE': 'DOUBLE PRECISION', 3948 'FLOAT': 'FLOAT', 3949 'UUID': 'VARCHAR(40)', 3950 'UUIDB': 'VARBINARY(16)'} 3951 operations = { 3952 'LIKE': 'LIKE BINARY', 3953 'ILIKE': 'LIKE', 3954 'REGEXP': 'REGEXP BINARY', 3955 'IREGEXP': 'REGEXP', 3956 'XOR': 'XOR'} 3957 param = '%s' 3958 quote = '``' 3959 3960 commit_select = True 3961 compound_select_parentheses = CSQ_PARENTHESES_UNNESTED 3962 for_update = True 3963 index_using_precedes_table = True 3964 limit_max = 2 ** 64 - 1 3965 safe_create_index = False 3966 safe_drop_index = False 3967 sql_mode = 'PIPES_AS_CONCAT' 3968 3969 def init(self, database, **kwargs): 3970 params = { 3971 'charset': 'utf8', 3972 'sql_mode': self.sql_mode, 3973 'use_unicode': True} 3974 params.update(kwargs) 3975 if 'password' in params and mysql_passwd: 3976 params['passwd'] = params.pop('password') 3977 super(MySQLDatabase, self).init(database, **params) 3978 3979 def _connect(self): 3980 if mysql is None: 3981 raise ImproperlyConfigured('MySQL driver not installed!') 3982 conn = mysql.connect(db=self.database, **self.connect_params) 3983 return conn 3984 3985 def _set_server_version(self, conn): 3986 try: 3987 version_raw = conn.server_version 3988 except AttributeError: 3989 version_raw = conn.get_server_info() 3990 self.server_version = self._extract_server_version(version_raw) 3991 3992 def _extract_server_version(self, version): 3993 version = version.lower() 3994 if 'maria' in version: 3995 match_obj = re.search(r'(1\d\.\d+\.\d+)', version) 3996 else: 3997 match_obj = re.search(r'(\d\.\d+\.\d+)', version) 3998 if match_obj is not None: 3999 return tuple(int(num) for num in match_obj.groups()[0].split('.')) 4000 4001 warnings.warn('Unable to determine MySQL version: "%s"' % version) 4002 return (0, 0, 0) # Unable to determine version! 4003 4004 def default_values_insert(self, ctx): 4005 return ctx.literal('() VALUES ()') 4006 4007 def get_tables(self, schema=None): 4008 query = ('SELECT table_name FROM information_schema.tables ' 4009 'WHERE table_schema = DATABASE() AND table_type != %s ' 4010 'ORDER BY table_name') 4011 return [table for table, in self.execute_sql(query, ('VIEW',))] 4012 4013 def get_views(self, schema=None): 4014 query = ('SELECT table_name, view_definition ' 4015 'FROM information_schema.views ' 4016 'WHERE table_schema = DATABASE() ORDER BY table_name') 4017 cursor = self.execute_sql(query) 4018 return [ViewMetadata(*row) for row in cursor.fetchall()] 4019 4020 def get_indexes(self, table, schema=None): 4021 cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table) 4022 unique = set() 4023 indexes = {} 4024 for row in cursor.fetchall(): 4025 if not row[1]: 4026 unique.add(row[2]) 4027 indexes.setdefault(row[2], []) 4028 indexes[row[2]].append(row[4]) 4029 return [IndexMetadata(name, None, indexes[name], name in unique, table) 4030 for name in indexes] 4031 4032 def get_columns(self, table, schema=None): 4033 sql = """ 4034 SELECT column_name, is_nullable, data_type, column_default 4035 FROM information_schema.columns 4036 WHERE table_name = %s AND table_schema = DATABASE()""" 4037 cursor = self.execute_sql(sql, (table,)) 4038 pks = set(self.get_primary_keys(table)) 4039 return [ColumnMetadata(name, dt, null == 'YES', name in pks, table, df) 4040 for name, null, dt, df in cursor.fetchall()] 4041 4042 def get_primary_keys(self, table, schema=None): 4043 cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table) 4044 return [row[4] for row in 4045 filter(lambda row: row[2] == 'PRIMARY', cursor.fetchall())] 4046 4047 def get_foreign_keys(self, table, schema=None): 4048 query = """ 4049 SELECT column_name, referenced_table_name, referenced_column_name 4050 FROM information_schema.key_column_usage 4051 WHERE table_name = %s 4052 AND table_schema = DATABASE() 4053 AND referenced_table_name IS NOT NULL 4054 AND referenced_column_name IS NOT NULL""" 4055 cursor = self.execute_sql(query, (table,)) 4056 return [ 4057 ForeignKeyMetadata(column, dest_table, dest_column, table) 4058 for column, dest_table, dest_column in cursor.fetchall()] 4059 4060 def get_binary_type(self): 4061 return mysql.Binary 4062 4063 def conflict_statement(self, on_conflict, query): 4064 if not on_conflict._action: return 4065 4066 action = on_conflict._action.lower() 4067 if action == 'replace': 4068 return SQL('REPLACE') 4069 elif action == 'ignore': 4070 return SQL('INSERT IGNORE') 4071 elif action != 'update': 4072 raise ValueError('Un-supported action for conflict resolution. ' 4073 'MySQL supports REPLACE, IGNORE and UPDATE.') 4074 4075 def conflict_update(self, on_conflict, query): 4076 if on_conflict._where or on_conflict._conflict_target or \ 4077 on_conflict._conflict_constraint: 4078 raise ValueError('MySQL does not support the specification of ' 4079 'where clauses or conflict targets for conflict ' 4080 'resolution.') 4081 4082 updates = [] 4083 if on_conflict._preserve: 4084 # Here we need to determine which function to use, which varies 4085 # depending on the MySQL server version. MySQL and MariaDB prior to 4086 # 10.3.3 use "VALUES", while MariaDB 10.3.3+ use "VALUE". 4087 version = self.server_version or (0,) 4088 if version[0] == 10 and version >= (10, 3, 3): 4089 VALUE_FN = fn.VALUE 4090 else: 4091 VALUE_FN = fn.VALUES 4092 4093 for column in on_conflict._preserve: 4094 entity = ensure_entity(column) 4095 expression = NodeList(( 4096 ensure_entity(column), 4097 SQL('='), 4098 VALUE_FN(entity))) 4099 updates.append(expression) 4100 4101 if on_conflict._update: 4102 for k, v in on_conflict._update.items(): 4103 if not isinstance(v, Node): 4104 # Attempt to resolve string field-names to their respective 4105 # field object, to apply data-type conversions. 4106 if isinstance(k, basestring): 4107 k = getattr(query.table, k) 4108 if isinstance(k, Field): 4109 v = k.to_value(v) 4110 else: 4111 v = Value(v, unpack=False) 4112 updates.append(NodeList((ensure_entity(k), SQL('='), v))) 4113 4114 if updates: 4115 return NodeList((SQL('ON DUPLICATE KEY UPDATE'), 4116 CommaNodeList(updates))) 4117 4118 def extract_date(self, date_part, date_field): 4119 return fn.EXTRACT(NodeList((SQL(date_part), SQL('FROM'), date_field))) 4120 4121 def truncate_date(self, date_part, date_field): 4122 return fn.DATE_FORMAT(date_field, __mysql_date_trunc__[date_part], 4123 python_value=simple_date_time) 4124 4125 def to_timestamp(self, date_field): 4126 return fn.UNIX_TIMESTAMP(date_field) 4127 4128 def from_timestamp(self, date_field): 4129 return fn.FROM_UNIXTIME(date_field) 4130 4131 def random(self): 4132 return fn.rand() 4133 4134 def get_noop_select(self, ctx): 4135 return ctx.literal('DO 0') 4136 4137 4138# TRANSACTION CONTROL. 4139 4140 4141class _manual(_callable_context_manager): 4142 def __init__(self, db): 4143 self.db = db 4144 4145 def __enter__(self): 4146 top = self.db.top_transaction() 4147 if top is not None and not isinstance(top, _manual): 4148 raise ValueError('Cannot enter manual commit block while a ' 4149 'transaction is active.') 4150 self.db.push_transaction(self) 4151 4152 def __exit__(self, exc_type, exc_val, exc_tb): 4153 if self.db.pop_transaction() is not self: 4154 raise ValueError('Transaction stack corrupted while exiting ' 4155 'manual commit block.') 4156 4157 4158class _atomic(_callable_context_manager): 4159 def __init__(self, db, *args, **kwargs): 4160 self.db = db 4161 self._transaction_args = (args, kwargs) 4162 4163 def __enter__(self): 4164 if self.db.transaction_depth() == 0: 4165 args, kwargs = self._transaction_args 4166 self._helper = self.db.transaction(*args, **kwargs) 4167 elif isinstance(self.db.top_transaction(), _manual): 4168 raise ValueError('Cannot enter atomic commit block while in ' 4169 'manual commit mode.') 4170 else: 4171 self._helper = self.db.savepoint() 4172 return self._helper.__enter__() 4173 4174 def __exit__(self, exc_type, exc_val, exc_tb): 4175 return self._helper.__exit__(exc_type, exc_val, exc_tb) 4176 4177 4178class _transaction(_callable_context_manager): 4179 def __init__(self, db, *args, **kwargs): 4180 self.db = db 4181 self._begin_args = (args, kwargs) 4182 4183 def _begin(self): 4184 args, kwargs = self._begin_args 4185 self.db.begin(*args, **kwargs) 4186 4187 def commit(self, begin=True): 4188 self.db.commit() 4189 if begin: 4190 self._begin() 4191 4192 def rollback(self, begin=True): 4193 self.db.rollback() 4194 if begin: 4195 self._begin() 4196 4197 def __enter__(self): 4198 if self.db.transaction_depth() == 0: 4199 self._begin() 4200 self.db.push_transaction(self) 4201 return self 4202 4203 def __exit__(self, exc_type, exc_val, exc_tb): 4204 try: 4205 if exc_type: 4206 self.rollback(False) 4207 elif self.db.transaction_depth() == 1: 4208 try: 4209 self.commit(False) 4210 except: 4211 self.rollback(False) 4212 raise 4213 finally: 4214 self.db.pop_transaction() 4215 4216 4217class _savepoint(_callable_context_manager): 4218 def __init__(self, db, sid=None): 4219 self.db = db 4220 self.sid = sid or 's' + uuid.uuid4().hex 4221 self.quoted_sid = self.sid.join(self.db.quote) 4222 4223 def _begin(self): 4224 self.db.execute_sql('SAVEPOINT %s;' % self.quoted_sid) 4225 4226 def commit(self, begin=True): 4227 self.db.execute_sql('RELEASE SAVEPOINT %s;' % self.quoted_sid) 4228 if begin: self._begin() 4229 4230 def rollback(self): 4231 self.db.execute_sql('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid) 4232 4233 def __enter__(self): 4234 self._begin() 4235 return self 4236 4237 def __exit__(self, exc_type, exc_val, exc_tb): 4238 if exc_type: 4239 self.rollback() 4240 else: 4241 try: 4242 self.commit(begin=False) 4243 except: 4244 self.rollback() 4245 raise 4246 4247 4248# CURSOR REPRESENTATIONS. 4249 4250 4251class CursorWrapper(object): 4252 def __init__(self, cursor): 4253 self.cursor = cursor 4254 self.count = 0 4255 self.index = 0 4256 self.initialized = False 4257 self.populated = False 4258 self.row_cache = [] 4259 4260 def __iter__(self): 4261 if self.populated: 4262 return iter(self.row_cache) 4263 return ResultIterator(self) 4264 4265 def __getitem__(self, item): 4266 if isinstance(item, slice): 4267 stop = item.stop 4268 if stop is None or stop < 0: 4269 self.fill_cache() 4270 else: 4271 self.fill_cache(stop) 4272 return self.row_cache[item] 4273 elif isinstance(item, int): 4274 self.fill_cache(item if item > 0 else 0) 4275 return self.row_cache[item] 4276 else: 4277 raise ValueError('CursorWrapper only supports integer and slice ' 4278 'indexes.') 4279 4280 def __len__(self): 4281 self.fill_cache() 4282 return self.count 4283 4284 def initialize(self): 4285 pass 4286 4287 def iterate(self, cache=True): 4288 row = self.cursor.fetchone() 4289 if row is None: 4290 self.populated = True 4291 self.cursor.close() 4292 raise StopIteration 4293 elif not self.initialized: 4294 self.initialize() # Lazy initialization. 4295 self.initialized = True 4296 self.count += 1 4297 result = self.process_row(row) 4298 if cache: 4299 self.row_cache.append(result) 4300 return result 4301 4302 def process_row(self, row): 4303 return row 4304 4305 def iterator(self): 4306 """Efficient one-pass iteration over the result set.""" 4307 while True: 4308 try: 4309 yield self.iterate(False) 4310 except StopIteration: 4311 return 4312 4313 def fill_cache(self, n=0): 4314 n = n or float('Inf') 4315 if n < 0: 4316 raise ValueError('Negative values are not supported.') 4317 4318 iterator = ResultIterator(self) 4319 iterator.index = self.count 4320 while not self.populated and (n > self.count): 4321 try: 4322 iterator.next() 4323 except StopIteration: 4324 break 4325 4326 4327class DictCursorWrapper(CursorWrapper): 4328 def _initialize_columns(self): 4329 description = self.cursor.description 4330 self.columns = [t[0][t[0].find('.') + 1:].strip('")') 4331 for t in description] 4332 self.ncols = len(description) 4333 4334 initialize = _initialize_columns 4335 4336 def _row_to_dict(self, row): 4337 result = {} 4338 for i in range(self.ncols): 4339 result.setdefault(self.columns[i], row[i]) # Do not overwrite. 4340 return result 4341 4342 process_row = _row_to_dict 4343 4344 4345class NamedTupleCursorWrapper(CursorWrapper): 4346 def initialize(self): 4347 description = self.cursor.description 4348 self.tuple_class = collections.namedtuple( 4349 'Row', 4350 [col[0][col[0].find('.') + 1:].strip('"') for col in description]) 4351 4352 def process_row(self, row): 4353 return self.tuple_class(*row) 4354 4355 4356class ObjectCursorWrapper(DictCursorWrapper): 4357 def __init__(self, cursor, constructor): 4358 super(ObjectCursorWrapper, self).__init__(cursor) 4359 self.constructor = constructor 4360 4361 def process_row(self, row): 4362 row_dict = self._row_to_dict(row) 4363 return self.constructor(**row_dict) 4364 4365 4366class ResultIterator(object): 4367 def __init__(self, cursor_wrapper): 4368 self.cursor_wrapper = cursor_wrapper 4369 self.index = 0 4370 4371 def __iter__(self): 4372 return self 4373 4374 def next(self): 4375 if self.index < self.cursor_wrapper.count: 4376 obj = self.cursor_wrapper.row_cache[self.index] 4377 elif not self.cursor_wrapper.populated: 4378 self.cursor_wrapper.iterate() 4379 obj = self.cursor_wrapper.row_cache[self.index] 4380 else: 4381 raise StopIteration 4382 self.index += 1 4383 return obj 4384 4385 __next__ = next 4386 4387# FIELDS 4388 4389class FieldAccessor(object): 4390 def __init__(self, model, field, name): 4391 self.model = model 4392 self.field = field 4393 self.name = name 4394 4395 def __get__(self, instance, instance_type=None): 4396 if instance is not None: 4397 return instance.__data__.get(self.name) 4398 return self.field 4399 4400 def __set__(self, instance, value): 4401 instance.__data__[self.name] = value 4402 instance._dirty.add(self.name) 4403 4404 4405class ForeignKeyAccessor(FieldAccessor): 4406 def __init__(self, model, field, name): 4407 super(ForeignKeyAccessor, self).__init__(model, field, name) 4408 self.rel_model = field.rel_model 4409 4410 def get_rel_instance(self, instance): 4411 value = instance.__data__.get(self.name) 4412 if value is not None or self.name in instance.__rel__: 4413 if self.name not in instance.__rel__ and self.field.lazy_load: 4414 obj = self.rel_model.get(self.field.rel_field == value) 4415 instance.__rel__[self.name] = obj 4416 return instance.__rel__.get(self.name, value) 4417 elif not self.field.null: 4418 raise self.rel_model.DoesNotExist 4419 return value 4420 4421 def __get__(self, instance, instance_type=None): 4422 if instance is not None: 4423 return self.get_rel_instance(instance) 4424 return self.field 4425 4426 def __set__(self, instance, obj): 4427 if isinstance(obj, self.rel_model): 4428 instance.__data__[self.name] = getattr(obj, self.field.rel_field.name) 4429 instance.__rel__[self.name] = obj 4430 else: 4431 fk_value = instance.__data__.get(self.name) 4432 instance.__data__[self.name] = obj 4433 if obj != fk_value and self.name in instance.__rel__: 4434 del instance.__rel__[self.name] 4435 instance._dirty.add(self.name) 4436 4437 4438class BackrefAccessor(object): 4439 def __init__(self, field): 4440 self.field = field 4441 self.model = field.rel_model 4442 self.rel_model = field.model 4443 4444 def __get__(self, instance, instance_type=None): 4445 if instance is not None: 4446 dest = self.field.rel_field.name 4447 return (self.rel_model 4448 .select() 4449 .where(self.field == getattr(instance, dest))) 4450 return self 4451 4452 4453class ObjectIdAccessor(object): 4454 """Gives direct access to the underlying id""" 4455 def __init__(self, field): 4456 self.field = field 4457 4458 def __get__(self, instance, instance_type=None): 4459 if instance is not None: 4460 value = instance.__data__.get(self.field.name) 4461 # Pull the object-id from the related object if it is not set. 4462 if value is None and self.field.name in instance.__rel__: 4463 rel_obj = instance.__rel__[self.field.name] 4464 value = getattr(rel_obj, self.field.rel_field.name) 4465 return value 4466 return self.field 4467 4468 def __set__(self, instance, value): 4469 setattr(instance, self.field.name, value) 4470 4471 4472class Field(ColumnBase): 4473 _field_counter = 0 4474 _order = 0 4475 accessor_class = FieldAccessor 4476 auto_increment = False 4477 default_index_type = None 4478 field_type = 'DEFAULT' 4479 unpack = True 4480 4481 def __init__(self, null=False, index=False, unique=False, column_name=None, 4482 default=None, primary_key=False, constraints=None, 4483 sequence=None, collation=None, unindexed=False, choices=None, 4484 help_text=None, verbose_name=None, index_type=None, 4485 db_column=None, _hidden=False): 4486 if db_column is not None: 4487 __deprecated__('"db_column" has been deprecated in favor of ' 4488 '"column_name" for Field objects.') 4489 column_name = db_column 4490 4491 self.null = null 4492 self.index = index 4493 self.unique = unique 4494 self.column_name = column_name 4495 self.default = default 4496 self.primary_key = primary_key 4497 self.constraints = constraints # List of column constraints. 4498 self.sequence = sequence # Name of sequence, e.g. foo_id_seq. 4499 self.collation = collation 4500 self.unindexed = unindexed 4501 self.choices = choices 4502 self.help_text = help_text 4503 self.verbose_name = verbose_name 4504 self.index_type = index_type or self.default_index_type 4505 self._hidden = _hidden 4506 4507 # Used internally for recovering the order in which Fields were defined 4508 # on the Model class. 4509 Field._field_counter += 1 4510 self._order = Field._field_counter 4511 self._sort_key = (self.primary_key and 1 or 2), self._order 4512 4513 def __hash__(self): 4514 return hash(self.name + '.' + self.model.__name__) 4515 4516 def __repr__(self): 4517 if hasattr(self, 'model') and getattr(self, 'name', None): 4518 return '<%s: %s.%s>' % (type(self).__name__, 4519 self.model.__name__, 4520 self.name) 4521 return '<%s: (unbound)>' % type(self).__name__ 4522 4523 def bind(self, model, name, set_attribute=True): 4524 self.model = model 4525 self.name = self.safe_name = name 4526 self.column_name = self.column_name or name 4527 if set_attribute: 4528 setattr(model, name, self.accessor_class(model, self, name)) 4529 4530 @property 4531 def column(self): 4532 return Column(self.model._meta.table, self.column_name) 4533 4534 def adapt(self, value): 4535 return value 4536 4537 def db_value(self, value): 4538 return value if value is None else self.adapt(value) 4539 4540 def python_value(self, value): 4541 return value if value is None else self.adapt(value) 4542 4543 def to_value(self, value): 4544 return Value(value, self.db_value, unpack=False) 4545 4546 def get_sort_key(self, ctx): 4547 return self._sort_key 4548 4549 def __sql__(self, ctx): 4550 return ctx.sql(self.column) 4551 4552 def get_modifiers(self): 4553 pass 4554 4555 def ddl_datatype(self, ctx): 4556 if ctx and ctx.state.field_types: 4557 column_type = ctx.state.field_types.get(self.field_type, 4558 self.field_type) 4559 else: 4560 column_type = self.field_type 4561 4562 modifiers = self.get_modifiers() 4563 if column_type and modifiers: 4564 modifier_literal = ', '.join([str(m) for m in modifiers]) 4565 return SQL('%s(%s)' % (column_type, modifier_literal)) 4566 else: 4567 return SQL(column_type) 4568 4569 def ddl(self, ctx): 4570 accum = [Entity(self.column_name)] 4571 data_type = self.ddl_datatype(ctx) 4572 if data_type: 4573 accum.append(data_type) 4574 if self.unindexed: 4575 accum.append(SQL('UNINDEXED')) 4576 if not self.null: 4577 accum.append(SQL('NOT NULL')) 4578 if self.primary_key: 4579 accum.append(SQL('PRIMARY KEY')) 4580 if self.sequence: 4581 accum.append(SQL("DEFAULT NEXTVAL('%s')" % self.sequence)) 4582 if self.constraints: 4583 accum.extend(self.constraints) 4584 if self.collation: 4585 accum.append(SQL('COLLATE %s' % self.collation)) 4586 return NodeList(accum) 4587 4588 4589class IntegerField(Field): 4590 field_type = 'INT' 4591 4592 def adapt(self, value): 4593 try: 4594 return int(value) 4595 except ValueError: 4596 return value 4597 4598 4599class BigIntegerField(IntegerField): 4600 field_type = 'BIGINT' 4601 4602 4603class SmallIntegerField(IntegerField): 4604 field_type = 'SMALLINT' 4605 4606 4607class AutoField(IntegerField): 4608 auto_increment = True 4609 field_type = 'AUTO' 4610 4611 def __init__(self, *args, **kwargs): 4612 if kwargs.get('primary_key') is False: 4613 raise ValueError('%s must always be a primary key.' % type(self)) 4614 kwargs['primary_key'] = True 4615 super(AutoField, self).__init__(*args, **kwargs) 4616 4617 4618class BigAutoField(AutoField): 4619 field_type = 'BIGAUTO' 4620 4621 4622class IdentityField(AutoField): 4623 field_type = 'INT GENERATED BY DEFAULT AS IDENTITY' 4624 4625 def __init__(self, generate_always=False, **kwargs): 4626 if generate_always: 4627 self.field_type = 'INT GENERATED ALWAYS AS IDENTITY' 4628 super(IdentityField, self).__init__(**kwargs) 4629 4630 4631class PrimaryKeyField(AutoField): 4632 def __init__(self, *args, **kwargs): 4633 __deprecated__('"PrimaryKeyField" has been renamed to "AutoField". ' 4634 'Please update your code accordingly as this will be ' 4635 'completely removed in a subsequent release.') 4636 super(PrimaryKeyField, self).__init__(*args, **kwargs) 4637 4638 4639class FloatField(Field): 4640 field_type = 'FLOAT' 4641 4642 def adapt(self, value): 4643 try: 4644 return float(value) 4645 except ValueError: 4646 return value 4647 4648 4649class DoubleField(FloatField): 4650 field_type = 'DOUBLE' 4651 4652 4653class DecimalField(Field): 4654 field_type = 'DECIMAL' 4655 4656 def __init__(self, max_digits=10, decimal_places=5, auto_round=False, 4657 rounding=None, *args, **kwargs): 4658 self.max_digits = max_digits 4659 self.decimal_places = decimal_places 4660 self.auto_round = auto_round 4661 self.rounding = rounding or decimal.DefaultContext.rounding 4662 self._exp = decimal.Decimal(10) ** (-self.decimal_places) 4663 super(DecimalField, self).__init__(*args, **kwargs) 4664 4665 def get_modifiers(self): 4666 return [self.max_digits, self.decimal_places] 4667 4668 def db_value(self, value): 4669 D = decimal.Decimal 4670 if not value: 4671 return value if value is None else D(0) 4672 if self.auto_round: 4673 decimal_value = D(text_type(value)) 4674 return decimal_value.quantize(self._exp, rounding=self.rounding) 4675 return value 4676 4677 def python_value(self, value): 4678 if value is not None: 4679 if isinstance(value, decimal.Decimal): 4680 return value 4681 return decimal.Decimal(text_type(value)) 4682 4683 4684class _StringField(Field): 4685 def adapt(self, value): 4686 if isinstance(value, text_type): 4687 return value 4688 elif isinstance(value, bytes_type): 4689 return value.decode('utf-8') 4690 return text_type(value) 4691 4692 def __add__(self, other): return StringExpression(self, OP.CONCAT, other) 4693 def __radd__(self, other): return StringExpression(other, OP.CONCAT, self) 4694 4695 4696class CharField(_StringField): 4697 field_type = 'VARCHAR' 4698 4699 def __init__(self, max_length=255, *args, **kwargs): 4700 self.max_length = max_length 4701 super(CharField, self).__init__(*args, **kwargs) 4702 4703 def get_modifiers(self): 4704 return self.max_length and [self.max_length] or None 4705 4706 4707class FixedCharField(CharField): 4708 field_type = 'CHAR' 4709 4710 def python_value(self, value): 4711 value = super(FixedCharField, self).python_value(value) 4712 if value: 4713 value = value.strip() 4714 return value 4715 4716 4717class TextField(_StringField): 4718 field_type = 'TEXT' 4719 4720 4721class BlobField(Field): 4722 field_type = 'BLOB' 4723 4724 def _db_hook(self, database): 4725 if database is None: 4726 self._constructor = bytearray 4727 else: 4728 self._constructor = database.get_binary_type() 4729 4730 def bind(self, model, name, set_attribute=True): 4731 self._constructor = bytearray 4732 if model._meta.database: 4733 if isinstance(model._meta.database, Proxy): 4734 model._meta.database.attach_callback(self._db_hook) 4735 else: 4736 self._db_hook(model._meta.database) 4737 4738 # Attach a hook to the model metadata; in the event the database is 4739 # changed or set at run-time, we will be sure to apply our callback and 4740 # use the proper data-type for our database driver. 4741 model._meta._db_hooks.append(self._db_hook) 4742 return super(BlobField, self).bind(model, name, set_attribute) 4743 4744 def db_value(self, value): 4745 if isinstance(value, text_type): 4746 value = value.encode('raw_unicode_escape') 4747 if isinstance(value, bytes_type): 4748 return self._constructor(value) 4749 return value 4750 4751 4752class BitField(BitwiseMixin, BigIntegerField): 4753 def __init__(self, *args, **kwargs): 4754 kwargs.setdefault('default', 0) 4755 super(BitField, self).__init__(*args, **kwargs) 4756 self.__current_flag = 1 4757 4758 def flag(self, value=None): 4759 if value is None: 4760 value = self.__current_flag 4761 self.__current_flag <<= 1 4762 else: 4763 self.__current_flag = value << 1 4764 4765 class FlagDescriptor(ColumnBase): 4766 def __init__(self, field, value): 4767 self._field = field 4768 self._value = value 4769 super(FlagDescriptor, self).__init__() 4770 def clear(self): 4771 return self._field.bin_and(~self._value) 4772 def set(self): 4773 return self._field.bin_or(self._value) 4774 def __get__(self, instance, instance_type=None): 4775 if instance is None: 4776 return self 4777 value = getattr(instance, self._field.name) or 0 4778 return (value & self._value) != 0 4779 def __set__(self, instance, is_set): 4780 if is_set not in (True, False): 4781 raise ValueError('Value must be either True or False') 4782 value = getattr(instance, self._field.name) or 0 4783 if is_set: 4784 value |= self._value 4785 else: 4786 value &= ~self._value 4787 setattr(instance, self._field.name, value) 4788 def __sql__(self, ctx): 4789 return ctx.sql(self._field.bin_and(self._value) != 0) 4790 return FlagDescriptor(self, value) 4791 4792 4793class BigBitFieldData(object): 4794 def __init__(self, instance, name): 4795 self.instance = instance 4796 self.name = name 4797 value = self.instance.__data__.get(self.name) 4798 if not value: 4799 value = bytearray() 4800 elif not isinstance(value, bytearray): 4801 value = bytearray(value) 4802 self._buffer = self.instance.__data__[self.name] = value 4803 4804 def _ensure_length(self, idx): 4805 byte_num, byte_offset = divmod(idx, 8) 4806 cur_size = len(self._buffer) 4807 if cur_size <= byte_num: 4808 self._buffer.extend(b'\x00' * ((byte_num + 1) - cur_size)) 4809 return byte_num, byte_offset 4810 4811 def set_bit(self, idx): 4812 byte_num, byte_offset = self._ensure_length(idx) 4813 self._buffer[byte_num] |= (1 << byte_offset) 4814 4815 def clear_bit(self, idx): 4816 byte_num, byte_offset = self._ensure_length(idx) 4817 self._buffer[byte_num] &= ~(1 << byte_offset) 4818 4819 def toggle_bit(self, idx): 4820 byte_num, byte_offset = self._ensure_length(idx) 4821 self._buffer[byte_num] ^= (1 << byte_offset) 4822 return bool(self._buffer[byte_num] & (1 << byte_offset)) 4823 4824 def is_set(self, idx): 4825 byte_num, byte_offset = self._ensure_length(idx) 4826 return bool(self._buffer[byte_num] & (1 << byte_offset)) 4827 4828 def __repr__(self): 4829 return repr(self._buffer) 4830 4831 4832class BigBitFieldAccessor(FieldAccessor): 4833 def __get__(self, instance, instance_type=None): 4834 if instance is None: 4835 return self.field 4836 return BigBitFieldData(instance, self.name) 4837 def __set__(self, instance, value): 4838 if isinstance(value, memoryview): 4839 value = value.tobytes() 4840 elif isinstance(value, buffer_type): 4841 value = bytes(value) 4842 elif isinstance(value, bytearray): 4843 value = bytes_type(value) 4844 elif isinstance(value, BigBitFieldData): 4845 value = bytes_type(value._buffer) 4846 elif isinstance(value, text_type): 4847 value = value.encode('utf-8') 4848 elif not isinstance(value, bytes_type): 4849 raise ValueError('Value must be either a bytes, memoryview or ' 4850 'BigBitFieldData instance.') 4851 super(BigBitFieldAccessor, self).__set__(instance, value) 4852 4853 4854class BigBitField(BlobField): 4855 accessor_class = BigBitFieldAccessor 4856 4857 def __init__(self, *args, **kwargs): 4858 kwargs.setdefault('default', bytes_type) 4859 super(BigBitField, self).__init__(*args, **kwargs) 4860 4861 def db_value(self, value): 4862 return bytes_type(value) if value is not None else value 4863 4864 4865class UUIDField(Field): 4866 field_type = 'UUID' 4867 4868 def db_value(self, value): 4869 if isinstance(value, basestring) and len(value) == 32: 4870 # Hex string. No transformation is necessary. 4871 return value 4872 elif isinstance(value, bytes) and len(value) == 16: 4873 # Allow raw binary representation. 4874 value = uuid.UUID(bytes=value) 4875 if isinstance(value, uuid.UUID): 4876 return value.hex 4877 try: 4878 return uuid.UUID(value).hex 4879 except: 4880 return value 4881 4882 def python_value(self, value): 4883 if isinstance(value, uuid.UUID): 4884 return value 4885 return uuid.UUID(value) if value is not None else None 4886 4887 4888class BinaryUUIDField(BlobField): 4889 field_type = 'UUIDB' 4890 4891 def db_value(self, value): 4892 if isinstance(value, bytes) and len(value) == 16: 4893 # Raw binary value. No transformation is necessary. 4894 return self._constructor(value) 4895 elif isinstance(value, basestring) and len(value) == 32: 4896 # Allow hex string representation. 4897 value = uuid.UUID(hex=value) 4898 if isinstance(value, uuid.UUID): 4899 return self._constructor(value.bytes) 4900 elif value is not None: 4901 raise ValueError('value for binary UUID field must be UUID(), ' 4902 'a hexadecimal string, or a bytes object.') 4903 4904 def python_value(self, value): 4905 if isinstance(value, uuid.UUID): 4906 return value 4907 elif isinstance(value, memoryview): 4908 value = value.tobytes() 4909 elif value and not isinstance(value, bytes): 4910 value = bytes(value) 4911 return uuid.UUID(bytes=value) if value is not None else None 4912 4913 4914def _date_part(date_part): 4915 def dec(self): 4916 return self.model._meta.database.extract_date(date_part, self) 4917 return dec 4918 4919def format_date_time(value, formats, post_process=None): 4920 post_process = post_process or (lambda x: x) 4921 for fmt in formats: 4922 try: 4923 return post_process(datetime.datetime.strptime(value, fmt)) 4924 except ValueError: 4925 pass 4926 return value 4927 4928def simple_date_time(value): 4929 try: 4930 return datetime.datetime.strptime(value, '%Y-%m-%d %H:%M:%S') 4931 except (TypeError, ValueError): 4932 return value 4933 4934 4935class _BaseFormattedField(Field): 4936 formats = None 4937 4938 def __init__(self, formats=None, *args, **kwargs): 4939 if formats is not None: 4940 self.formats = formats 4941 super(_BaseFormattedField, self).__init__(*args, **kwargs) 4942 4943 4944class DateTimeField(_BaseFormattedField): 4945 field_type = 'DATETIME' 4946 formats = [ 4947 '%Y-%m-%d %H:%M:%S.%f', 4948 '%Y-%m-%d %H:%M:%S', 4949 '%Y-%m-%d', 4950 ] 4951 4952 def adapt(self, value): 4953 if value and isinstance(value, basestring): 4954 return format_date_time(value, self.formats) 4955 return value 4956 4957 def to_timestamp(self): 4958 return self.model._meta.database.to_timestamp(self) 4959 4960 def truncate(self, part): 4961 return self.model._meta.database.truncate_date(part, self) 4962 4963 year = property(_date_part('year')) 4964 month = property(_date_part('month')) 4965 day = property(_date_part('day')) 4966 hour = property(_date_part('hour')) 4967 minute = property(_date_part('minute')) 4968 second = property(_date_part('second')) 4969 4970 4971class DateField(_BaseFormattedField): 4972 field_type = 'DATE' 4973 formats = [ 4974 '%Y-%m-%d', 4975 '%Y-%m-%d %H:%M:%S', 4976 '%Y-%m-%d %H:%M:%S.%f', 4977 ] 4978 4979 def adapt(self, value): 4980 if value and isinstance(value, basestring): 4981 pp = lambda x: x.date() 4982 return format_date_time(value, self.formats, pp) 4983 elif value and isinstance(value, datetime.datetime): 4984 return value.date() 4985 return value 4986 4987 def to_timestamp(self): 4988 return self.model._meta.database.to_timestamp(self) 4989 4990 def truncate(self, part): 4991 return self.model._meta.database.truncate_date(part, self) 4992 4993 year = property(_date_part('year')) 4994 month = property(_date_part('month')) 4995 day = property(_date_part('day')) 4996 4997 4998class TimeField(_BaseFormattedField): 4999 field_type = 'TIME' 5000 formats = [ 5001 '%H:%M:%S.%f', 5002 '%H:%M:%S', 5003 '%H:%M', 5004 '%Y-%m-%d %H:%M:%S.%f', 5005 '%Y-%m-%d %H:%M:%S', 5006 ] 5007 5008 def adapt(self, value): 5009 if value: 5010 if isinstance(value, basestring): 5011 pp = lambda x: x.time() 5012 return format_date_time(value, self.formats, pp) 5013 elif isinstance(value, datetime.datetime): 5014 return value.time() 5015 if value is not None and isinstance(value, datetime.timedelta): 5016 return (datetime.datetime.min + value).time() 5017 return value 5018 5019 hour = property(_date_part('hour')) 5020 minute = property(_date_part('minute')) 5021 second = property(_date_part('second')) 5022 5023 5024def _timestamp_date_part(date_part): 5025 def dec(self): 5026 db = self.model._meta.database 5027 expr = ((self / Value(self.resolution, converter=False)) 5028 if self.resolution > 1 else self) 5029 return db.extract_date(date_part, db.from_timestamp(expr)) 5030 return dec 5031 5032 5033class TimestampField(BigIntegerField): 5034 # Support second -> microsecond resolution. 5035 valid_resolutions = [10**i for i in range(7)] 5036 5037 def __init__(self, *args, **kwargs): 5038 self.resolution = kwargs.pop('resolution', None) 5039 5040 if not self.resolution: 5041 self.resolution = 1 5042 elif self.resolution in range(2, 7): 5043 self.resolution = 10 ** self.resolution 5044 elif self.resolution not in self.valid_resolutions: 5045 raise ValueError('TimestampField resolution must be one of: %s' % 5046 ', '.join(str(i) for i in self.valid_resolutions)) 5047 self.ticks_to_microsecond = 1000000 // self.resolution 5048 5049 self.utc = kwargs.pop('utc', False) or False 5050 dflt = datetime.datetime.utcnow if self.utc else datetime.datetime.now 5051 kwargs.setdefault('default', dflt) 5052 super(TimestampField, self).__init__(*args, **kwargs) 5053 5054 def local_to_utc(self, dt): 5055 # Convert naive local datetime into naive UTC, e.g.: 5056 # 2019-03-01T12:00:00 (local=US/Central) -> 2019-03-01T18:00:00. 5057 # 2019-05-01T12:00:00 (local=US/Central) -> 2019-05-01T17:00:00. 5058 # 2019-03-01T12:00:00 (local=UTC) -> 2019-03-01T12:00:00. 5059 return datetime.datetime(*time.gmtime(time.mktime(dt.timetuple()))[:6]) 5060 5061 def utc_to_local(self, dt): 5062 # Convert a naive UTC datetime into local time, e.g.: 5063 # 2019-03-01T18:00:00 (local=US/Central) -> 2019-03-01T12:00:00. 5064 # 2019-05-01T17:00:00 (local=US/Central) -> 2019-05-01T12:00:00. 5065 # 2019-03-01T12:00:00 (local=UTC) -> 2019-03-01T12:00:00. 5066 ts = calendar.timegm(dt.utctimetuple()) 5067 return datetime.datetime.fromtimestamp(ts) 5068 5069 def get_timestamp(self, value): 5070 if self.utc: 5071 # If utc-mode is on, then we assume all naive datetimes are in UTC. 5072 return calendar.timegm(value.utctimetuple()) 5073 else: 5074 return time.mktime(value.timetuple()) 5075 5076 def db_value(self, value): 5077 if value is None: 5078 return 5079 5080 if isinstance(value, datetime.datetime): 5081 pass 5082 elif isinstance(value, datetime.date): 5083 value = datetime.datetime(value.year, value.month, value.day) 5084 else: 5085 return int(round(value * self.resolution)) 5086 5087 timestamp = self.get_timestamp(value) 5088 if self.resolution > 1: 5089 timestamp += (value.microsecond * .000001) 5090 timestamp *= self.resolution 5091 return int(round(timestamp)) 5092 5093 def python_value(self, value): 5094 if value is not None and isinstance(value, (int, float, long)): 5095 if self.resolution > 1: 5096 value, ticks = divmod(value, self.resolution) 5097 microseconds = int(ticks * self.ticks_to_microsecond) 5098 else: 5099 microseconds = 0 5100 5101 if self.utc: 5102 value = datetime.datetime.utcfromtimestamp(value) 5103 else: 5104 value = datetime.datetime.fromtimestamp(value) 5105 5106 if microseconds: 5107 value = value.replace(microsecond=microseconds) 5108 5109 return value 5110 5111 def from_timestamp(self): 5112 expr = ((self / Value(self.resolution, converter=False)) 5113 if self.resolution > 1 else self) 5114 return self.model._meta.database.from_timestamp(expr) 5115 5116 year = property(_timestamp_date_part('year')) 5117 month = property(_timestamp_date_part('month')) 5118 day = property(_timestamp_date_part('day')) 5119 hour = property(_timestamp_date_part('hour')) 5120 minute = property(_timestamp_date_part('minute')) 5121 second = property(_timestamp_date_part('second')) 5122 5123 5124class IPField(BigIntegerField): 5125 def db_value(self, val): 5126 if val is not None: 5127 return struct.unpack('!I', socket.inet_aton(val))[0] 5128 5129 def python_value(self, val): 5130 if val is not None: 5131 return socket.inet_ntoa(struct.pack('!I', val)) 5132 5133 5134class BooleanField(Field): 5135 field_type = 'BOOL' 5136 adapt = bool 5137 5138 5139class BareField(Field): 5140 def __init__(self, adapt=None, *args, **kwargs): 5141 super(BareField, self).__init__(*args, **kwargs) 5142 if adapt is not None: 5143 self.adapt = adapt 5144 5145 def ddl_datatype(self, ctx): 5146 return 5147 5148 5149class ForeignKeyField(Field): 5150 accessor_class = ForeignKeyAccessor 5151 5152 def __init__(self, model, field=None, backref=None, on_delete=None, 5153 on_update=None, deferrable=None, _deferred=None, 5154 rel_model=None, to_field=None, object_id_name=None, 5155 lazy_load=True, constraint_name=None, related_name=None, 5156 *args, **kwargs): 5157 kwargs.setdefault('index', True) 5158 5159 super(ForeignKeyField, self).__init__(*args, **kwargs) 5160 5161 if rel_model is not None: 5162 __deprecated__('"rel_model" has been deprecated in favor of ' 5163 '"model" for ForeignKeyField objects.') 5164 model = rel_model 5165 if to_field is not None: 5166 __deprecated__('"to_field" has been deprecated in favor of ' 5167 '"field" for ForeignKeyField objects.') 5168 field = to_field 5169 if related_name is not None: 5170 __deprecated__('"related_name" has been deprecated in favor of ' 5171 '"backref" for Field objects.') 5172 backref = related_name 5173 5174 self._is_self_reference = model == 'self' 5175 self.rel_model = model 5176 self.rel_field = field 5177 self.declared_backref = backref 5178 self.backref = None 5179 self.on_delete = on_delete 5180 self.on_update = on_update 5181 self.deferrable = deferrable 5182 self.deferred = _deferred 5183 self.object_id_name = object_id_name 5184 self.lazy_load = lazy_load 5185 self.constraint_name = constraint_name 5186 5187 @property 5188 def field_type(self): 5189 if not isinstance(self.rel_field, AutoField): 5190 return self.rel_field.field_type 5191 elif isinstance(self.rel_field, BigAutoField): 5192 return BigIntegerField.field_type 5193 return IntegerField.field_type 5194 5195 def get_modifiers(self): 5196 if not isinstance(self.rel_field, AutoField): 5197 return self.rel_field.get_modifiers() 5198 return super(ForeignKeyField, self).get_modifiers() 5199 5200 def adapt(self, value): 5201 return self.rel_field.adapt(value) 5202 5203 def db_value(self, value): 5204 if isinstance(value, self.rel_model): 5205 value = getattr(value, self.rel_field.name) 5206 return self.rel_field.db_value(value) 5207 5208 def python_value(self, value): 5209 if isinstance(value, self.rel_model): 5210 return value 5211 return self.rel_field.python_value(value) 5212 5213 def bind(self, model, name, set_attribute=True): 5214 if not self.column_name: 5215 self.column_name = name if name.endswith('_id') else name + '_id' 5216 if not self.object_id_name: 5217 self.object_id_name = self.column_name 5218 if self.object_id_name == name: 5219 self.object_id_name += '_id' 5220 elif self.object_id_name == name: 5221 raise ValueError('ForeignKeyField "%s"."%s" specifies an ' 5222 'object_id_name that conflicts with its field ' 5223 'name.' % (model._meta.name, name)) 5224 if self._is_self_reference: 5225 self.rel_model = model 5226 if isinstance(self.rel_field, basestring): 5227 self.rel_field = getattr(self.rel_model, self.rel_field) 5228 elif self.rel_field is None: 5229 self.rel_field = self.rel_model._meta.primary_key 5230 5231 # Bind field before assigning backref, so field is bound when 5232 # calling declared_backref() (if callable). 5233 super(ForeignKeyField, self).bind(model, name, set_attribute) 5234 self.safe_name = self.object_id_name 5235 5236 if callable_(self.declared_backref): 5237 self.backref = self.declared_backref(self) 5238 else: 5239 self.backref, self.declared_backref = self.declared_backref, None 5240 if not self.backref: 5241 self.backref = '%s_set' % model._meta.name 5242 5243 if set_attribute: 5244 setattr(model, self.object_id_name, ObjectIdAccessor(self)) 5245 if self.backref not in '!+': 5246 setattr(self.rel_model, self.backref, BackrefAccessor(self)) 5247 5248 def foreign_key_constraint(self): 5249 parts = [] 5250 if self.constraint_name: 5251 parts.extend((SQL('CONSTRAINT'), Entity(self.constraint_name))) 5252 parts.extend([ 5253 SQL('FOREIGN KEY'), 5254 EnclosedNodeList((self,)), 5255 SQL('REFERENCES'), 5256 self.rel_model, 5257 EnclosedNodeList((self.rel_field,))]) 5258 if self.on_delete: 5259 parts.append(SQL('ON DELETE %s' % self.on_delete)) 5260 if self.on_update: 5261 parts.append(SQL('ON UPDATE %s' % self.on_update)) 5262 if self.deferrable: 5263 parts.append(SQL('DEFERRABLE %s' % self.deferrable)) 5264 return NodeList(parts) 5265 5266 def __getattr__(self, attr): 5267 if attr.startswith('__'): 5268 # Prevent recursion error when deep-copying. 5269 raise AttributeError('Cannot look-up non-existant "__" methods.') 5270 if attr in self.rel_model._meta.fields: 5271 return self.rel_model._meta.fields[attr] 5272 raise AttributeError('Foreign-key has no attribute %s, nor is it a ' 5273 'valid field on the related model.' % attr) 5274 5275 5276class DeferredForeignKey(Field): 5277 _unresolved = set() 5278 5279 def __init__(self, rel_model_name, **kwargs): 5280 self.field_kwargs = kwargs 5281 self.rel_model_name = rel_model_name.lower() 5282 DeferredForeignKey._unresolved.add(self) 5283 super(DeferredForeignKey, self).__init__( 5284 column_name=kwargs.get('column_name'), 5285 null=kwargs.get('null')) 5286 5287 __hash__ = object.__hash__ 5288 5289 def __deepcopy__(self, memo=None): 5290 return DeferredForeignKey(self.rel_model_name, **self.field_kwargs) 5291 5292 def set_model(self, rel_model): 5293 field = ForeignKeyField(rel_model, _deferred=True, **self.field_kwargs) 5294 self.model._meta.add_field(self.name, field) 5295 5296 @staticmethod 5297 def resolve(model_cls): 5298 unresolved = sorted(DeferredForeignKey._unresolved, 5299 key=operator.attrgetter('_order')) 5300 for dr in unresolved: 5301 if dr.rel_model_name == model_cls.__name__.lower(): 5302 dr.set_model(model_cls) 5303 DeferredForeignKey._unresolved.discard(dr) 5304 5305 5306class DeferredThroughModel(object): 5307 def __init__(self): 5308 self._refs = [] 5309 5310 def set_field(self, model, field, name): 5311 self._refs.append((model, field, name)) 5312 5313 def set_model(self, through_model): 5314 for src_model, m2mfield, name in self._refs: 5315 m2mfield.through_model = through_model 5316 src_model._meta.add_field(name, m2mfield) 5317 5318 5319class MetaField(Field): 5320 column_name = default = model = name = None 5321 primary_key = False 5322 5323 5324class ManyToManyFieldAccessor(FieldAccessor): 5325 def __init__(self, model, field, name): 5326 super(ManyToManyFieldAccessor, self).__init__(model, field, name) 5327 self.model = field.model 5328 self.rel_model = field.rel_model 5329 self.through_model = field.through_model 5330 src_fks = self.through_model._meta.model_refs[self.model] 5331 dest_fks = self.through_model._meta.model_refs[self.rel_model] 5332 if not src_fks: 5333 raise ValueError('Cannot find foreign-key to "%s" on "%s" model.' % 5334 (self.model, self.through_model)) 5335 elif not dest_fks: 5336 raise ValueError('Cannot find foreign-key to "%s" on "%s" model.' % 5337 (self.rel_model, self.through_model)) 5338 self.src_fk = src_fks[0] 5339 self.dest_fk = dest_fks[0] 5340 5341 def __get__(self, instance, instance_type=None, force_query=False): 5342 if instance is not None: 5343 if not force_query and self.src_fk.backref != '+': 5344 backref = getattr(instance, self.src_fk.backref) 5345 if isinstance(backref, list): 5346 return [getattr(obj, self.dest_fk.name) for obj in backref] 5347 5348 src_id = getattr(instance, self.src_fk.rel_field.name) 5349 return (ManyToManyQuery(instance, self, self.rel_model) 5350 .join(self.through_model) 5351 .join(self.model) 5352 .where(self.src_fk == src_id)) 5353 5354 return self.field 5355 5356 def __set__(self, instance, value): 5357 query = self.__get__(instance, force_query=True) 5358 query.add(value, clear_existing=True) 5359 5360 5361class ManyToManyField(MetaField): 5362 accessor_class = ManyToManyFieldAccessor 5363 5364 def __init__(self, model, backref=None, through_model=None, on_delete=None, 5365 on_update=None, _is_backref=False): 5366 if through_model is not None: 5367 if not (isinstance(through_model, DeferredThroughModel) or 5368 is_model(through_model)): 5369 raise TypeError('Unexpected value for through_model. Expected ' 5370 'Model or DeferredThroughModel.') 5371 if not _is_backref and (on_delete is not None or on_update is not None): 5372 raise ValueError('Cannot specify on_delete or on_update when ' 5373 'through_model is specified.') 5374 self.rel_model = model 5375 self.backref = backref 5376 self._through_model = through_model 5377 self._on_delete = on_delete 5378 self._on_update = on_update 5379 self._is_backref = _is_backref 5380 5381 def _get_descriptor(self): 5382 return ManyToManyFieldAccessor(self) 5383 5384 def bind(self, model, name, set_attribute=True): 5385 if isinstance(self._through_model, DeferredThroughModel): 5386 self._through_model.set_field(model, self, name) 5387 return 5388 5389 super(ManyToManyField, self).bind(model, name, set_attribute) 5390 5391 if not self._is_backref: 5392 many_to_many_field = ManyToManyField( 5393 self.model, 5394 backref=name, 5395 through_model=self.through_model, 5396 on_delete=self._on_delete, 5397 on_update=self._on_update, 5398 _is_backref=True) 5399 self.backref = self.backref or model._meta.name + 's' 5400 self.rel_model._meta.add_field(self.backref, many_to_many_field) 5401 5402 def get_models(self): 5403 return [model for _, model in sorted(( 5404 (self._is_backref, self.model), 5405 (not self._is_backref, self.rel_model)))] 5406 5407 @property 5408 def through_model(self): 5409 if self._through_model is None: 5410 self._through_model = self._create_through_model() 5411 return self._through_model 5412 5413 @through_model.setter 5414 def through_model(self, value): 5415 self._through_model = value 5416 5417 def _create_through_model(self): 5418 lhs, rhs = self.get_models() 5419 tables = [model._meta.table_name for model in (lhs, rhs)] 5420 5421 class Meta: 5422 database = self.model._meta.database 5423 schema = self.model._meta.schema 5424 table_name = '%s_%s_through' % tuple(tables) 5425 indexes = ( 5426 ((lhs._meta.name, rhs._meta.name), 5427 True),) 5428 5429 params = {'on_delete': self._on_delete, 'on_update': self._on_update} 5430 attrs = { 5431 lhs._meta.name: ForeignKeyField(lhs, **params), 5432 rhs._meta.name: ForeignKeyField(rhs, **params), 5433 'Meta': Meta} 5434 5435 klass_name = '%s%sThrough' % (lhs.__name__, rhs.__name__) 5436 return type(klass_name, (Model,), attrs) 5437 5438 def get_through_model(self): 5439 # XXX: Deprecated. Just use the "through_model" property. 5440 return self.through_model 5441 5442 5443class VirtualField(MetaField): 5444 field_class = None 5445 5446 def __init__(self, field_class=None, *args, **kwargs): 5447 Field = field_class if field_class is not None else self.field_class 5448 self.field_instance = Field() if Field is not None else None 5449 super(VirtualField, self).__init__(*args, **kwargs) 5450 5451 def db_value(self, value): 5452 if self.field_instance is not None: 5453 return self.field_instance.db_value(value) 5454 return value 5455 5456 def python_value(self, value): 5457 if self.field_instance is not None: 5458 return self.field_instance.python_value(value) 5459 return value 5460 5461 def bind(self, model, name, set_attribute=True): 5462 self.model = model 5463 self.column_name = self.name = self.safe_name = name 5464 setattr(model, name, self.accessor_class(model, self, name)) 5465 5466 5467class CompositeKey(MetaField): 5468 sequence = None 5469 5470 def __init__(self, *field_names): 5471 self.field_names = field_names 5472 self._safe_field_names = None 5473 5474 @property 5475 def safe_field_names(self): 5476 if self._safe_field_names is None: 5477 if self.model is None: 5478 return self.field_names 5479 5480 self._safe_field_names = [self.model._meta.fields[f].safe_name 5481 for f in self.field_names] 5482 return self._safe_field_names 5483 5484 def __get__(self, instance, instance_type=None): 5485 if instance is not None: 5486 return tuple([getattr(instance, f) for f in self.safe_field_names]) 5487 return self 5488 5489 def __set__(self, instance, value): 5490 if not isinstance(value, (list, tuple)): 5491 raise TypeError('A list or tuple must be used to set the value of ' 5492 'a composite primary key.') 5493 if len(value) != len(self.field_names): 5494 raise ValueError('The length of the value must equal the number ' 5495 'of columns of the composite primary key.') 5496 for idx, field_value in enumerate(value): 5497 setattr(instance, self.field_names[idx], field_value) 5498 5499 def __eq__(self, other): 5500 expressions = [(self.model._meta.fields[field] == value) 5501 for field, value in zip(self.field_names, other)] 5502 return reduce(operator.and_, expressions) 5503 5504 def __ne__(self, other): 5505 return ~(self == other) 5506 5507 def __hash__(self): 5508 return hash((self.model.__name__, self.field_names)) 5509 5510 def __sql__(self, ctx): 5511 # If the composite PK is being selected, do not use parens. Elsewhere, 5512 # such as in an expression, we want to use parentheses and treat it as 5513 # a row value. 5514 parens = ctx.scope != SCOPE_SOURCE 5515 return ctx.sql(NodeList([self.model._meta.fields[field] 5516 for field in self.field_names], ', ', parens)) 5517 5518 def bind(self, model, name, set_attribute=True): 5519 self.model = model 5520 self.column_name = self.name = self.safe_name = name 5521 setattr(model, self.name, self) 5522 5523 5524class _SortedFieldList(object): 5525 __slots__ = ('_keys', '_items') 5526 5527 def __init__(self): 5528 self._keys = [] 5529 self._items = [] 5530 5531 def __getitem__(self, i): 5532 return self._items[i] 5533 5534 def __iter__(self): 5535 return iter(self._items) 5536 5537 def __contains__(self, item): 5538 k = item._sort_key 5539 i = bisect_left(self._keys, k) 5540 j = bisect_right(self._keys, k) 5541 return item in self._items[i:j] 5542 5543 def index(self, field): 5544 return self._keys.index(field._sort_key) 5545 5546 def insert(self, item): 5547 k = item._sort_key 5548 i = bisect_left(self._keys, k) 5549 self._keys.insert(i, k) 5550 self._items.insert(i, item) 5551 5552 def remove(self, item): 5553 idx = self.index(item) 5554 del self._items[idx] 5555 del self._keys[idx] 5556 5557 5558# MODELS 5559 5560 5561class SchemaManager(object): 5562 def __init__(self, model, database=None, **context_options): 5563 self.model = model 5564 self._database = database 5565 context_options.setdefault('scope', SCOPE_VALUES) 5566 self.context_options = context_options 5567 5568 @property 5569 def database(self): 5570 db = self._database or self.model._meta.database 5571 if db is None: 5572 raise ImproperlyConfigured('database attribute does not appear to ' 5573 'be set on the model: %s' % self.model) 5574 return db 5575 5576 @database.setter 5577 def database(self, value): 5578 self._database = value 5579 5580 def _create_context(self): 5581 return self.database.get_sql_context(**self.context_options) 5582 5583 def _create_table(self, safe=True, **options): 5584 is_temp = options.pop('temporary', False) 5585 ctx = self._create_context() 5586 ctx.literal('CREATE TEMPORARY TABLE ' if is_temp else 'CREATE TABLE ') 5587 if safe: 5588 ctx.literal('IF NOT EXISTS ') 5589 ctx.sql(self.model).literal(' ') 5590 5591 columns = [] 5592 constraints = [] 5593 meta = self.model._meta 5594 if meta.composite_key: 5595 pk_columns = [meta.fields[field_name].column 5596 for field_name in meta.primary_key.field_names] 5597 constraints.append(NodeList((SQL('PRIMARY KEY'), 5598 EnclosedNodeList(pk_columns)))) 5599 5600 for field in meta.sorted_fields: 5601 columns.append(field.ddl(ctx)) 5602 if isinstance(field, ForeignKeyField) and not field.deferred: 5603 constraints.append(field.foreign_key_constraint()) 5604 5605 if meta.constraints: 5606 constraints.extend(meta.constraints) 5607 5608 constraints.extend(self._create_table_option_sql(options)) 5609 ctx.sql(EnclosedNodeList(columns + constraints)) 5610 5611 if meta.table_settings is not None: 5612 table_settings = ensure_tuple(meta.table_settings) 5613 for setting in table_settings: 5614 if not isinstance(setting, basestring): 5615 raise ValueError('table_settings must be strings') 5616 ctx.literal(' ').literal(setting) 5617 5618 if meta.without_rowid: 5619 ctx.literal(' WITHOUT ROWID') 5620 return ctx 5621 5622 def _create_table_option_sql(self, options): 5623 accum = [] 5624 options = merge_dict(self.model._meta.options or {}, options) 5625 if not options: 5626 return accum 5627 5628 for key, value in sorted(options.items()): 5629 if not isinstance(value, Node): 5630 if is_model(value): 5631 value = value._meta.table 5632 else: 5633 value = SQL(str(value)) 5634 accum.append(NodeList((SQL(key), value), glue='=')) 5635 return accum 5636 5637 def create_table(self, safe=True, **options): 5638 self.database.execute(self._create_table(safe=safe, **options)) 5639 5640 def _create_table_as(self, table_name, query, safe=True, **meta): 5641 ctx = (self._create_context() 5642 .literal('CREATE TEMPORARY TABLE ' 5643 if meta.get('temporary') else 'CREATE TABLE ')) 5644 if safe: 5645 ctx.literal('IF NOT EXISTS ') 5646 return (ctx 5647 .sql(Entity(table_name)) 5648 .literal(' AS ') 5649 .sql(query)) 5650 5651 def create_table_as(self, table_name, query, safe=True, **meta): 5652 ctx = self._create_table_as(table_name, query, safe=safe, **meta) 5653 self.database.execute(ctx) 5654 5655 def _drop_table(self, safe=True, **options): 5656 ctx = (self._create_context() 5657 .literal('DROP TABLE IF EXISTS ' if safe else 'DROP TABLE ') 5658 .sql(self.model)) 5659 if options.get('cascade'): 5660 ctx = ctx.literal(' CASCADE') 5661 elif options.get('restrict'): 5662 ctx = ctx.literal(' RESTRICT') 5663 return ctx 5664 5665 def drop_table(self, safe=True, **options): 5666 self.database.execute(self._drop_table(safe=safe, **options)) 5667 5668 def _truncate_table(self, restart_identity=False, cascade=False): 5669 db = self.database 5670 if not db.truncate_table: 5671 return (self._create_context() 5672 .literal('DELETE FROM ').sql(self.model)) 5673 5674 ctx = self._create_context().literal('TRUNCATE TABLE ').sql(self.model) 5675 if restart_identity: 5676 ctx = ctx.literal(' RESTART IDENTITY') 5677 if cascade: 5678 ctx = ctx.literal(' CASCADE') 5679 return ctx 5680 5681 def truncate_table(self, restart_identity=False, cascade=False): 5682 self.database.execute(self._truncate_table(restart_identity, cascade)) 5683 5684 def _create_indexes(self, safe=True): 5685 return [self._create_index(index, safe) 5686 for index in self.model._meta.fields_to_index()] 5687 5688 def _create_index(self, index, safe=True): 5689 if isinstance(index, Index): 5690 if not self.database.safe_create_index: 5691 index = index.safe(False) 5692 elif index._safe != safe: 5693 index = index.safe(safe) 5694 return self._create_context().sql(index) 5695 5696 def create_indexes(self, safe=True): 5697 for query in self._create_indexes(safe=safe): 5698 self.database.execute(query) 5699 5700 def _drop_indexes(self, safe=True): 5701 return [self._drop_index(index, safe) 5702 for index in self.model._meta.fields_to_index() 5703 if isinstance(index, Index)] 5704 5705 def _drop_index(self, index, safe): 5706 statement = 'DROP INDEX ' 5707 if safe and self.database.safe_drop_index: 5708 statement += 'IF EXISTS ' 5709 if isinstance(index._table, Table) and index._table._schema: 5710 index_name = Entity(index._table._schema, index._name) 5711 else: 5712 index_name = Entity(index._name) 5713 return (self 5714 ._create_context() 5715 .literal(statement) 5716 .sql(index_name)) 5717 5718 def drop_indexes(self, safe=True): 5719 for query in self._drop_indexes(safe=safe): 5720 self.database.execute(query) 5721 5722 def _check_sequences(self, field): 5723 if not field.sequence or not self.database.sequences: 5724 raise ValueError('Sequences are either not supported, or are not ' 5725 'defined for "%s".' % field.name) 5726 5727 def _sequence_for_field(self, field): 5728 if field.model._meta.schema: 5729 return Entity(field.model._meta.schema, field.sequence) 5730 else: 5731 return Entity(field.sequence) 5732 5733 def _create_sequence(self, field): 5734 self._check_sequences(field) 5735 if not self.database.sequence_exists(field.sequence): 5736 return (self 5737 ._create_context() 5738 .literal('CREATE SEQUENCE ') 5739 .sql(self._sequence_for_field(field))) 5740 5741 def create_sequence(self, field): 5742 seq_ctx = self._create_sequence(field) 5743 if seq_ctx is not None: 5744 self.database.execute(seq_ctx) 5745 5746 def _drop_sequence(self, field): 5747 self._check_sequences(field) 5748 if self.database.sequence_exists(field.sequence): 5749 return (self 5750 ._create_context() 5751 .literal('DROP SEQUENCE ') 5752 .sql(self._sequence_for_field(field))) 5753 5754 def drop_sequence(self, field): 5755 seq_ctx = self._drop_sequence(field) 5756 if seq_ctx is not None: 5757 self.database.execute(seq_ctx) 5758 5759 def _create_foreign_key(self, field): 5760 name = 'fk_%s_%s_refs_%s' % (field.model._meta.table_name, 5761 field.column_name, 5762 field.rel_model._meta.table_name) 5763 return (self 5764 ._create_context() 5765 .literal('ALTER TABLE ') 5766 .sql(field.model) 5767 .literal(' ADD CONSTRAINT ') 5768 .sql(Entity(_truncate_constraint_name(name))) 5769 .literal(' ') 5770 .sql(field.foreign_key_constraint())) 5771 5772 def create_foreign_key(self, field): 5773 self.database.execute(self._create_foreign_key(field)) 5774 5775 def create_sequences(self): 5776 if self.database.sequences: 5777 for field in self.model._meta.sorted_fields: 5778 if field.sequence: 5779 self.create_sequence(field) 5780 5781 def create_all(self, safe=True, **table_options): 5782 self.create_sequences() 5783 self.create_table(safe, **table_options) 5784 self.create_indexes(safe=safe) 5785 5786 def drop_sequences(self): 5787 if self.database.sequences: 5788 for field in self.model._meta.sorted_fields: 5789 if field.sequence: 5790 self.drop_sequence(field) 5791 5792 def drop_all(self, safe=True, drop_sequences=True, **options): 5793 self.drop_table(safe, **options) 5794 if drop_sequences: 5795 self.drop_sequences() 5796 5797 5798class Metadata(object): 5799 def __init__(self, model, database=None, table_name=None, indexes=None, 5800 primary_key=None, constraints=None, schema=None, 5801 only_save_dirty=False, depends_on=None, options=None, 5802 db_table=None, table_function=None, table_settings=None, 5803 without_rowid=False, temporary=False, legacy_table_names=True, 5804 **kwargs): 5805 if db_table is not None: 5806 __deprecated__('"db_table" has been deprecated in favor of ' 5807 '"table_name" for Models.') 5808 table_name = db_table 5809 self.model = model 5810 self.database = database 5811 5812 self.fields = {} 5813 self.columns = {} 5814 self.combined = {} 5815 5816 self._sorted_field_list = _SortedFieldList() 5817 self.sorted_fields = [] 5818 self.sorted_field_names = [] 5819 5820 self.defaults = {} 5821 self._default_by_name = {} 5822 self._default_dict = {} 5823 self._default_callables = {} 5824 self._default_callable_list = [] 5825 5826 self.name = model.__name__.lower() 5827 self.table_function = table_function 5828 self.legacy_table_names = legacy_table_names 5829 if not table_name: 5830 table_name = (self.table_function(model) 5831 if self.table_function 5832 else self.make_table_name()) 5833 self.table_name = table_name 5834 self._table = None 5835 5836 self.indexes = list(indexes) if indexes else [] 5837 self.constraints = constraints 5838 self._schema = schema 5839 self.primary_key = primary_key 5840 self.composite_key = self.auto_increment = None 5841 self.only_save_dirty = only_save_dirty 5842 self.depends_on = depends_on 5843 self.table_settings = table_settings 5844 self.without_rowid = without_rowid 5845 self.temporary = temporary 5846 5847 self.refs = {} 5848 self.backrefs = {} 5849 self.model_refs = collections.defaultdict(list) 5850 self.model_backrefs = collections.defaultdict(list) 5851 self.manytomany = {} 5852 5853 self.options = options or {} 5854 for key, value in kwargs.items(): 5855 setattr(self, key, value) 5856 self._additional_keys = set(kwargs.keys()) 5857 5858 # Allow objects to register hooks that are called if the model is bound 5859 # to a different database. For example, BlobField uses a different 5860 # Python data-type depending on the db driver / python version. When 5861 # the database changes, we need to update any BlobField so they can use 5862 # the appropriate data-type. 5863 self._db_hooks = [] 5864 5865 def make_table_name(self): 5866 if self.legacy_table_names: 5867 return re.sub(r'[^\w]+', '_', self.name) 5868 return make_snake_case(self.model.__name__) 5869 5870 def model_graph(self, refs=True, backrefs=True, depth_first=True): 5871 if not refs and not backrefs: 5872 raise ValueError('One of `refs` or `backrefs` must be True.') 5873 5874 accum = [(None, self.model, None)] 5875 seen = set() 5876 queue = collections.deque((self,)) 5877 method = queue.pop if depth_first else queue.popleft 5878 5879 while queue: 5880 curr = method() 5881 if curr in seen: continue 5882 seen.add(curr) 5883 5884 if refs: 5885 for fk, model in curr.refs.items(): 5886 accum.append((fk, model, False)) 5887 queue.append(model._meta) 5888 if backrefs: 5889 for fk, model in curr.backrefs.items(): 5890 accum.append((fk, model, True)) 5891 queue.append(model._meta) 5892 5893 return accum 5894 5895 def add_ref(self, field): 5896 rel = field.rel_model 5897 self.refs[field] = rel 5898 self.model_refs[rel].append(field) 5899 rel._meta.backrefs[field] = self.model 5900 rel._meta.model_backrefs[self.model].append(field) 5901 5902 def remove_ref(self, field): 5903 rel = field.rel_model 5904 del self.refs[field] 5905 self.model_refs[rel].remove(field) 5906 del rel._meta.backrefs[field] 5907 rel._meta.model_backrefs[self.model].remove(field) 5908 5909 def add_manytomany(self, field): 5910 self.manytomany[field.name] = field 5911 5912 def remove_manytomany(self, field): 5913 del self.manytomany[field.name] 5914 5915 @property 5916 def table(self): 5917 if self._table is None: 5918 self._table = Table( 5919 self.table_name, 5920 [field.column_name for field in self.sorted_fields], 5921 schema=self.schema, 5922 _model=self.model, 5923 _database=self.database) 5924 return self._table 5925 5926 @table.setter 5927 def table(self, value): 5928 raise AttributeError('Cannot set the "table".') 5929 5930 @table.deleter 5931 def table(self): 5932 self._table = None 5933 5934 @property 5935 def schema(self): 5936 return self._schema 5937 5938 @schema.setter 5939 def schema(self, value): 5940 self._schema = value 5941 del self.table 5942 5943 @property 5944 def entity(self): 5945 if self._schema: 5946 return Entity(self._schema, self.table_name) 5947 else: 5948 return Entity(self.table_name) 5949 5950 def _update_sorted_fields(self): 5951 self.sorted_fields = list(self._sorted_field_list) 5952 self.sorted_field_names = [f.name for f in self.sorted_fields] 5953 5954 def get_rel_for_model(self, model): 5955 if isinstance(model, ModelAlias): 5956 model = model.model 5957 forwardrefs = self.model_refs.get(model, []) 5958 backrefs = self.model_backrefs.get(model, []) 5959 return (forwardrefs, backrefs) 5960 5961 def add_field(self, field_name, field, set_attribute=True): 5962 if field_name in self.fields: 5963 self.remove_field(field_name) 5964 elif field_name in self.manytomany: 5965 self.remove_manytomany(self.manytomany[field_name]) 5966 5967 if not isinstance(field, MetaField): 5968 del self.table 5969 field.bind(self.model, field_name, set_attribute) 5970 self.fields[field.name] = field 5971 self.columns[field.column_name] = field 5972 self.combined[field.name] = field 5973 self.combined[field.column_name] = field 5974 5975 self._sorted_field_list.insert(field) 5976 self._update_sorted_fields() 5977 5978 if field.default is not None: 5979 # This optimization helps speed up model instance construction. 5980 self.defaults[field] = field.default 5981 if callable_(field.default): 5982 self._default_callables[field] = field.default 5983 self._default_callable_list.append((field.name, 5984 field.default)) 5985 else: 5986 self._default_dict[field] = field.default 5987 self._default_by_name[field.name] = field.default 5988 else: 5989 field.bind(self.model, field_name, set_attribute) 5990 5991 if isinstance(field, ForeignKeyField): 5992 self.add_ref(field) 5993 elif isinstance(field, ManyToManyField) and field.name: 5994 self.add_manytomany(field) 5995 5996 def remove_field(self, field_name): 5997 if field_name not in self.fields: 5998 return 5999 6000 del self.table 6001 original = self.fields.pop(field_name) 6002 del self.columns[original.column_name] 6003 del self.combined[field_name] 6004 try: 6005 del self.combined[original.column_name] 6006 except KeyError: 6007 pass 6008 self._sorted_field_list.remove(original) 6009 self._update_sorted_fields() 6010 6011 if original.default is not None: 6012 del self.defaults[original] 6013 if self._default_callables.pop(original, None): 6014 for i, (name, _) in enumerate(self._default_callable_list): 6015 if name == field_name: 6016 self._default_callable_list.pop(i) 6017 break 6018 else: 6019 self._default_dict.pop(original, None) 6020 self._default_by_name.pop(original.name, None) 6021 6022 if isinstance(original, ForeignKeyField): 6023 self.remove_ref(original) 6024 6025 def set_primary_key(self, name, field): 6026 self.composite_key = isinstance(field, CompositeKey) 6027 self.add_field(name, field) 6028 self.primary_key = field 6029 self.auto_increment = ( 6030 field.auto_increment or 6031 bool(field.sequence)) 6032 6033 def get_primary_keys(self): 6034 if self.composite_key: 6035 return tuple([self.fields[field_name] 6036 for field_name in self.primary_key.field_names]) 6037 else: 6038 return (self.primary_key,) if self.primary_key is not False else () 6039 6040 def get_default_dict(self): 6041 dd = self._default_by_name.copy() 6042 for field_name, default in self._default_callable_list: 6043 dd[field_name] = default() 6044 return dd 6045 6046 def fields_to_index(self): 6047 indexes = [] 6048 for f in self.sorted_fields: 6049 if f.primary_key: 6050 continue 6051 if f.index or f.unique: 6052 indexes.append(ModelIndex(self.model, (f,), unique=f.unique, 6053 using=f.index_type)) 6054 6055 for index_obj in self.indexes: 6056 if isinstance(index_obj, Node): 6057 indexes.append(index_obj) 6058 elif isinstance(index_obj, (list, tuple)): 6059 index_parts, unique = index_obj 6060 fields = [] 6061 for part in index_parts: 6062 if isinstance(part, basestring): 6063 fields.append(self.combined[part]) 6064 elif isinstance(part, Node): 6065 fields.append(part) 6066 else: 6067 raise ValueError('Expected either a field name or a ' 6068 'subclass of Node. Got: %s' % part) 6069 indexes.append(ModelIndex(self.model, fields, unique=unique)) 6070 6071 return indexes 6072 6073 def set_database(self, database): 6074 self.database = database 6075 self.model._schema._database = database 6076 del self.table 6077 6078 # Apply any hooks that have been registered. 6079 for hook in self._db_hooks: 6080 hook(database) 6081 6082 def set_table_name(self, table_name): 6083 self.table_name = table_name 6084 del self.table 6085 6086 6087class SubclassAwareMetadata(Metadata): 6088 models = [] 6089 6090 def __init__(self, model, *args, **kwargs): 6091 super(SubclassAwareMetadata, self).__init__(model, *args, **kwargs) 6092 self.models.append(model) 6093 6094 def map_models(self, fn): 6095 for model in self.models: 6096 fn(model) 6097 6098 6099class DoesNotExist(Exception): pass 6100 6101 6102class ModelBase(type): 6103 inheritable = set(['constraints', 'database', 'indexes', 'primary_key', 6104 'options', 'schema', 'table_function', 'temporary', 6105 'only_save_dirty', 'legacy_table_names', 6106 'table_settings']) 6107 6108 def __new__(cls, name, bases, attrs): 6109 if name == MODEL_BASE or bases[0].__name__ == MODEL_BASE: 6110 return super(ModelBase, cls).__new__(cls, name, bases, attrs) 6111 6112 meta_options = {} 6113 meta = attrs.pop('Meta', None) 6114 if meta: 6115 for k, v in meta.__dict__.items(): 6116 if not k.startswith('_'): 6117 meta_options[k] = v 6118 6119 pk = getattr(meta, 'primary_key', None) 6120 pk_name = parent_pk = None 6121 6122 # Inherit any field descriptors by deep copying the underlying field 6123 # into the attrs of the new model, additionally see if the bases define 6124 # inheritable model options and swipe them. 6125 for b in bases: 6126 if not hasattr(b, '_meta'): 6127 continue 6128 6129 base_meta = b._meta 6130 if parent_pk is None: 6131 parent_pk = deepcopy(base_meta.primary_key) 6132 all_inheritable = cls.inheritable | base_meta._additional_keys 6133 for k in base_meta.__dict__: 6134 if k in all_inheritable and k not in meta_options: 6135 meta_options[k] = base_meta.__dict__[k] 6136 meta_options.setdefault('schema', base_meta.schema) 6137 6138 for (k, v) in b.__dict__.items(): 6139 if k in attrs: continue 6140 6141 if isinstance(v, FieldAccessor) and not v.field.primary_key: 6142 attrs[k] = deepcopy(v.field) 6143 6144 sopts = meta_options.pop('schema_options', None) or {} 6145 Meta = meta_options.get('model_metadata_class', Metadata) 6146 Schema = meta_options.get('schema_manager_class', SchemaManager) 6147 6148 # Construct the new class. 6149 cls = super(ModelBase, cls).__new__(cls, name, bases, attrs) 6150 cls.__data__ = cls.__rel__ = None 6151 6152 cls._meta = Meta(cls, **meta_options) 6153 cls._schema = Schema(cls, **sopts) 6154 6155 fields = [] 6156 for key, value in cls.__dict__.items(): 6157 if isinstance(value, Field): 6158 if value.primary_key and pk: 6159 raise ValueError('over-determined primary key %s.' % name) 6160 elif value.primary_key: 6161 pk, pk_name = value, key 6162 else: 6163 fields.append((key, value)) 6164 6165 if pk is None: 6166 if parent_pk is not False: 6167 pk, pk_name = ((parent_pk, parent_pk.name) 6168 if parent_pk is not None else 6169 (AutoField(), 'id')) 6170 else: 6171 pk = False 6172 elif isinstance(pk, CompositeKey): 6173 pk_name = '__composite_key__' 6174 cls._meta.composite_key = True 6175 6176 if pk is not False: 6177 cls._meta.set_primary_key(pk_name, pk) 6178 6179 for name, field in fields: 6180 cls._meta.add_field(name, field) 6181 6182 # Create a repr and error class before finalizing. 6183 if hasattr(cls, '__str__') and '__repr__' not in attrs: 6184 setattr(cls, '__repr__', lambda self: '<%s: %s>' % ( 6185 cls.__name__, self.__str__())) 6186 6187 exc_name = '%sDoesNotExist' % cls.__name__ 6188 exc_attrs = {'__module__': cls.__module__} 6189 exception_class = type(exc_name, (DoesNotExist,), exc_attrs) 6190 cls.DoesNotExist = exception_class 6191 6192 # Call validation hook, allowing additional model validation. 6193 cls.validate_model() 6194 DeferredForeignKey.resolve(cls) 6195 return cls 6196 6197 def __repr__(self): 6198 return '<Model: %s>' % self.__name__ 6199 6200 def __iter__(self): 6201 return iter(self.select()) 6202 6203 def __getitem__(self, key): 6204 return self.get_by_id(key) 6205 6206 def __setitem__(self, key, value): 6207 self.set_by_id(key, value) 6208 6209 def __delitem__(self, key): 6210 self.delete_by_id(key) 6211 6212 def __contains__(self, key): 6213 try: 6214 self.get_by_id(key) 6215 except self.DoesNotExist: 6216 return False 6217 else: 6218 return True 6219 6220 def __len__(self): 6221 return self.select().count() 6222 def __bool__(self): return True 6223 __nonzero__ = __bool__ # Python 2. 6224 6225 def __sql__(self, ctx): 6226 return ctx.sql(self._meta.table) 6227 6228 6229class _BoundModelsContext(_callable_context_manager): 6230 def __init__(self, models, database, bind_refs, bind_backrefs): 6231 self.models = models 6232 self.database = database 6233 self.bind_refs = bind_refs 6234 self.bind_backrefs = bind_backrefs 6235 6236 def __enter__(self): 6237 self._orig_database = [] 6238 for model in self.models: 6239 self._orig_database.append(model._meta.database) 6240 model.bind(self.database, self.bind_refs, self.bind_backrefs, 6241 _exclude=set(self.models)) 6242 return self.models 6243 6244 def __exit__(self, exc_type, exc_val, exc_tb): 6245 for model, db in zip(self.models, self._orig_database): 6246 model.bind(db, self.bind_refs, self.bind_backrefs, 6247 _exclude=set(self.models)) 6248 6249 6250class Model(with_metaclass(ModelBase, Node)): 6251 def __init__(self, *args, **kwargs): 6252 if kwargs.pop('__no_default__', None): 6253 self.__data__ = {} 6254 else: 6255 self.__data__ = self._meta.get_default_dict() 6256 self._dirty = set(self.__data__) 6257 self.__rel__ = {} 6258 6259 for k in kwargs: 6260 setattr(self, k, kwargs[k]) 6261 6262 def __str__(self): 6263 return str(self._pk) if self._meta.primary_key is not False else 'n/a' 6264 6265 @classmethod 6266 def validate_model(cls): 6267 pass 6268 6269 @classmethod 6270 def alias(cls, alias=None): 6271 return ModelAlias(cls, alias) 6272 6273 @classmethod 6274 def select(cls, *fields): 6275 is_default = not fields 6276 if not fields: 6277 fields = cls._meta.sorted_fields 6278 return ModelSelect(cls, fields, is_default=is_default) 6279 6280 @classmethod 6281 def _normalize_data(cls, data, kwargs): 6282 normalized = {} 6283 if data: 6284 if not isinstance(data, dict): 6285 if kwargs: 6286 raise ValueError('Data cannot be mixed with keyword ' 6287 'arguments: %s' % data) 6288 return data 6289 for key in data: 6290 try: 6291 field = (key if isinstance(key, Field) 6292 else cls._meta.combined[key]) 6293 except KeyError: 6294 if not isinstance(key, Node): 6295 raise ValueError('Unrecognized field name: "%s" in %s.' 6296 % (key, data)) 6297 field = key 6298 normalized[field] = data[key] 6299 if kwargs: 6300 for key in kwargs: 6301 try: 6302 normalized[cls._meta.combined[key]] = kwargs[key] 6303 except KeyError: 6304 normalized[getattr(cls, key)] = kwargs[key] 6305 return normalized 6306 6307 @classmethod 6308 def update(cls, __data=None, **update): 6309 return ModelUpdate(cls, cls._normalize_data(__data, update)) 6310 6311 @classmethod 6312 def insert(cls, __data=None, **insert): 6313 return ModelInsert(cls, cls._normalize_data(__data, insert)) 6314 6315 @classmethod 6316 def insert_many(cls, rows, fields=None): 6317 return ModelInsert(cls, insert=rows, columns=fields) 6318 6319 @classmethod 6320 def insert_from(cls, query, fields): 6321 columns = [getattr(cls, field) if isinstance(field, basestring) 6322 else field for field in fields] 6323 return ModelInsert(cls, insert=query, columns=columns) 6324 6325 @classmethod 6326 def replace(cls, __data=None, **insert): 6327 return cls.insert(__data, **insert).on_conflict('REPLACE') 6328 6329 @classmethod 6330 def replace_many(cls, rows, fields=None): 6331 return (cls 6332 .insert_many(rows=rows, fields=fields) 6333 .on_conflict('REPLACE')) 6334 6335 @classmethod 6336 def raw(cls, sql, *params): 6337 return ModelRaw(cls, sql, params) 6338 6339 @classmethod 6340 def delete(cls): 6341 return ModelDelete(cls) 6342 6343 @classmethod 6344 def create(cls, **query): 6345 inst = cls(**query) 6346 inst.save(force_insert=True) 6347 return inst 6348 6349 @classmethod 6350 def bulk_create(cls, model_list, batch_size=None): 6351 if batch_size is not None: 6352 batches = chunked(model_list, batch_size) 6353 else: 6354 batches = [model_list] 6355 6356 field_names = list(cls._meta.sorted_field_names) 6357 if cls._meta.auto_increment: 6358 pk_name = cls._meta.primary_key.name 6359 field_names.remove(pk_name) 6360 6361 if cls._meta.database.returning_clause and \ 6362 cls._meta.primary_key is not False: 6363 pk_fields = cls._meta.get_primary_keys() 6364 else: 6365 pk_fields = None 6366 6367 fields = [cls._meta.fields[field_name] for field_name in field_names] 6368 attrs = [] 6369 for field in fields: 6370 if isinstance(field, ForeignKeyField): 6371 attrs.append(field.object_id_name) 6372 else: 6373 attrs.append(field.name) 6374 6375 for batch in batches: 6376 accum = ([getattr(model, f) for f in attrs] 6377 for model in batch) 6378 res = cls.insert_many(accum, fields=fields).execute() 6379 if pk_fields and res is not None: 6380 for row, model in zip(res, batch): 6381 for (pk_field, obj_id) in zip(pk_fields, row): 6382 setattr(model, pk_field.name, obj_id) 6383 6384 @classmethod 6385 def bulk_update(cls, model_list, fields, batch_size=None): 6386 if isinstance(cls._meta.primary_key, CompositeKey): 6387 raise ValueError('bulk_update() is not supported for models with ' 6388 'a composite primary key.') 6389 6390 # First normalize list of fields so all are field instances. 6391 fields = [cls._meta.fields[f] if isinstance(f, basestring) else f 6392 for f in fields] 6393 # Now collect list of attribute names to use for values. 6394 attrs = [field.object_id_name if isinstance(field, ForeignKeyField) 6395 else field.name for field in fields] 6396 6397 if batch_size is not None: 6398 batches = chunked(model_list, batch_size) 6399 else: 6400 batches = [model_list] 6401 6402 n = 0 6403 pk = cls._meta.primary_key 6404 6405 for batch in batches: 6406 id_list = [model._pk for model in batch] 6407 update = {} 6408 for field, attr in zip(fields, attrs): 6409 accum = [] 6410 for model in batch: 6411 value = getattr(model, attr) 6412 if not isinstance(value, Node): 6413 value = field.to_value(value) 6414 accum.append((pk.to_value(model._pk), value)) 6415 case = Case(pk, accum) 6416 update[field] = case 6417 6418 n += (cls.update(update) 6419 .where(cls._meta.primary_key.in_(id_list)) 6420 .execute()) 6421 return n 6422 6423 @classmethod 6424 def noop(cls): 6425 return NoopModelSelect(cls, ()) 6426 6427 @classmethod 6428 def get(cls, *query, **filters): 6429 sq = cls.select() 6430 if query: 6431 # Handle simple lookup using just the primary key. 6432 if len(query) == 1 and isinstance(query[0], int): 6433 sq = sq.where(cls._meta.primary_key == query[0]) 6434 else: 6435 sq = sq.where(*query) 6436 if filters: 6437 sq = sq.filter(**filters) 6438 return sq.get() 6439 6440 @classmethod 6441 def get_or_none(cls, *query, **filters): 6442 try: 6443 return cls.get(*query, **filters) 6444 except DoesNotExist: 6445 pass 6446 6447 @classmethod 6448 def get_by_id(cls, pk): 6449 return cls.get(cls._meta.primary_key == pk) 6450 6451 @classmethod 6452 def set_by_id(cls, key, value): 6453 if key is None: 6454 return cls.insert(value).execute() 6455 else: 6456 return (cls.update(value) 6457 .where(cls._meta.primary_key == key).execute()) 6458 6459 @classmethod 6460 def delete_by_id(cls, pk): 6461 return cls.delete().where(cls._meta.primary_key == pk).execute() 6462 6463 @classmethod 6464 def get_or_create(cls, **kwargs): 6465 defaults = kwargs.pop('defaults', {}) 6466 query = cls.select() 6467 for field, value in kwargs.items(): 6468 query = query.where(getattr(cls, field) == value) 6469 6470 try: 6471 return query.get(), False 6472 except cls.DoesNotExist: 6473 try: 6474 if defaults: 6475 kwargs.update(defaults) 6476 with cls._meta.database.atomic(): 6477 return cls.create(**kwargs), True 6478 except IntegrityError as exc: 6479 try: 6480 return query.get(), False 6481 except cls.DoesNotExist: 6482 raise exc 6483 6484 @classmethod 6485 def filter(cls, *dq_nodes, **filters): 6486 return cls.select().filter(*dq_nodes, **filters) 6487 6488 def get_id(self): 6489 # Using getattr(self, pk-name) could accidentally trigger a query if 6490 # the primary-key is a foreign-key. So we use the safe_name attribute, 6491 # which defaults to the field-name, but will be the object_id_name for 6492 # foreign-key fields. 6493 if self._meta.primary_key is not False: 6494 return getattr(self, self._meta.primary_key.safe_name) 6495 6496 _pk = property(get_id) 6497 6498 @_pk.setter 6499 def _pk(self, value): 6500 setattr(self, self._meta.primary_key.name, value) 6501 6502 def _pk_expr(self): 6503 return self._meta.primary_key == self._pk 6504 6505 def _prune_fields(self, field_dict, only): 6506 new_data = {} 6507 for field in only: 6508 if isinstance(field, basestring): 6509 field = self._meta.combined[field] 6510 if field.name in field_dict: 6511 new_data[field.name] = field_dict[field.name] 6512 return new_data 6513 6514 def _populate_unsaved_relations(self, field_dict): 6515 for foreign_key_field in self._meta.refs: 6516 foreign_key = foreign_key_field.name 6517 conditions = ( 6518 foreign_key in field_dict and 6519 field_dict[foreign_key] is None and 6520 self.__rel__.get(foreign_key) is not None) 6521 if conditions: 6522 setattr(self, foreign_key, getattr(self, foreign_key)) 6523 field_dict[foreign_key] = self.__data__[foreign_key] 6524 6525 def save(self, force_insert=False, only=None): 6526 field_dict = self.__data__.copy() 6527 if self._meta.primary_key is not False: 6528 pk_field = self._meta.primary_key 6529 pk_value = self._pk 6530 else: 6531 pk_field = pk_value = None 6532 if only is not None: 6533 field_dict = self._prune_fields(field_dict, only) 6534 elif self._meta.only_save_dirty and not force_insert: 6535 field_dict = self._prune_fields(field_dict, self.dirty_fields) 6536 if not field_dict: 6537 self._dirty.clear() 6538 return False 6539 6540 self._populate_unsaved_relations(field_dict) 6541 rows = 1 6542 6543 if self._meta.auto_increment and pk_value is None: 6544 field_dict.pop(pk_field.name, None) 6545 6546 if pk_value is not None and not force_insert: 6547 if self._meta.composite_key: 6548 for pk_part_name in pk_field.field_names: 6549 field_dict.pop(pk_part_name, None) 6550 else: 6551 field_dict.pop(pk_field.name, None) 6552 if not field_dict: 6553 raise ValueError('no data to save!') 6554 rows = self.update(**field_dict).where(self._pk_expr()).execute() 6555 elif pk_field is not None: 6556 pk = self.insert(**field_dict).execute() 6557 if pk is not None and (self._meta.auto_increment or 6558 pk_value is None): 6559 self._pk = pk 6560 else: 6561 self.insert(**field_dict).execute() 6562 6563 self._dirty.clear() 6564 return rows 6565 6566 def is_dirty(self): 6567 return bool(self._dirty) 6568 6569 @property 6570 def dirty_fields(self): 6571 return [f for f in self._meta.sorted_fields if f.name in self._dirty] 6572 6573 def dependencies(self, search_nullable=False): 6574 model_class = type(self) 6575 stack = [(type(self), None)] 6576 seen = set() 6577 6578 while stack: 6579 klass, query = stack.pop() 6580 if klass in seen: 6581 continue 6582 seen.add(klass) 6583 for fk, rel_model in klass._meta.backrefs.items(): 6584 if rel_model is model_class or query is None: 6585 node = (fk == self.__data__[fk.rel_field.name]) 6586 else: 6587 node = fk << query 6588 subquery = (rel_model.select(rel_model._meta.primary_key) 6589 .where(node)) 6590 if not fk.null or search_nullable: 6591 stack.append((rel_model, subquery)) 6592 yield (node, fk) 6593 6594 def delete_instance(self, recursive=False, delete_nullable=False): 6595 if recursive: 6596 dependencies = self.dependencies(delete_nullable) 6597 for query, fk in reversed(list(dependencies)): 6598 model = fk.model 6599 if fk.null and not delete_nullable: 6600 model.update(**{fk.name: None}).where(query).execute() 6601 else: 6602 model.delete().where(query).execute() 6603 return type(self).delete().where(self._pk_expr()).execute() 6604 6605 def __hash__(self): 6606 return hash((self.__class__, self._pk)) 6607 6608 def __eq__(self, other): 6609 return ( 6610 other.__class__ == self.__class__ and 6611 self._pk is not None and 6612 self._pk == other._pk) 6613 6614 def __ne__(self, other): 6615 return not self == other 6616 6617 def __sql__(self, ctx): 6618 # NOTE: when comparing a foreign-key field whose related-field is not a 6619 # primary-key, then doing an equality test for the foreign-key with a 6620 # model instance will return the wrong value; since we would return 6621 # the primary key for a given model instance. 6622 # 6623 # This checks to see if we have a converter in the scope, and that we 6624 # are converting a foreign-key expression. If so, we hand the model 6625 # instance to the converter rather than blindly grabbing the primary- 6626 # key. In the event the provided converter fails to handle the model 6627 # instance, then we will return the primary-key. 6628 if ctx.state.converter is not None and ctx.state.is_fk_expr: 6629 try: 6630 return ctx.sql(Value(self, converter=ctx.state.converter)) 6631 except (TypeError, ValueError): 6632 pass 6633 6634 return ctx.sql(Value(getattr(self, self._meta.primary_key.name), 6635 converter=self._meta.primary_key.db_value)) 6636 6637 @classmethod 6638 def bind(cls, database, bind_refs=True, bind_backrefs=True, _exclude=None): 6639 is_different = cls._meta.database is not database 6640 cls._meta.set_database(database) 6641 if bind_refs or bind_backrefs: 6642 if _exclude is None: 6643 _exclude = set() 6644 G = cls._meta.model_graph(refs=bind_refs, backrefs=bind_backrefs) 6645 for _, model, is_backref in G: 6646 if model not in _exclude: 6647 model._meta.set_database(database) 6648 _exclude.add(model) 6649 return is_different 6650 6651 @classmethod 6652 def bind_ctx(cls, database, bind_refs=True, bind_backrefs=True): 6653 return _BoundModelsContext((cls,), database, bind_refs, bind_backrefs) 6654 6655 @classmethod 6656 def table_exists(cls): 6657 M = cls._meta 6658 return cls._schema.database.table_exists(M.table.__name__, M.schema) 6659 6660 @classmethod 6661 def create_table(cls, safe=True, **options): 6662 if 'fail_silently' in options: 6663 __deprecated__('"fail_silently" has been deprecated in favor of ' 6664 '"safe" for the create_table() method.') 6665 safe = options.pop('fail_silently') 6666 6667 if safe and not cls._schema.database.safe_create_index \ 6668 and cls.table_exists(): 6669 return 6670 if cls._meta.temporary: 6671 options.setdefault('temporary', cls._meta.temporary) 6672 cls._schema.create_all(safe, **options) 6673 6674 @classmethod 6675 def drop_table(cls, safe=True, drop_sequences=True, **options): 6676 if safe and not cls._schema.database.safe_drop_index \ 6677 and not cls.table_exists(): 6678 return 6679 if cls._meta.temporary: 6680 options.setdefault('temporary', cls._meta.temporary) 6681 cls._schema.drop_all(safe, drop_sequences, **options) 6682 6683 @classmethod 6684 def truncate_table(cls, **options): 6685 cls._schema.truncate_table(**options) 6686 6687 @classmethod 6688 def index(cls, *fields, **kwargs): 6689 return ModelIndex(cls, fields, **kwargs) 6690 6691 @classmethod 6692 def add_index(cls, *fields, **kwargs): 6693 if len(fields) == 1 and isinstance(fields[0], (SQL, Index)): 6694 cls._meta.indexes.append(fields[0]) 6695 else: 6696 cls._meta.indexes.append(ModelIndex(cls, fields, **kwargs)) 6697 6698 6699class ModelAlias(Node): 6700 """Provide a separate reference to a model in a query.""" 6701 def __init__(self, model, alias=None): 6702 self.__dict__['model'] = model 6703 self.__dict__['alias'] = alias 6704 6705 def __getattr__(self, attr): 6706 # Hack to work-around the fact that properties or other objects 6707 # implementing the descriptor protocol (on the model being aliased), 6708 # will not work correctly when we use getattr(). So we explicitly pass 6709 # the model alias to the descriptor's getter. 6710 try: 6711 obj = self.model.__dict__[attr] 6712 except KeyError: 6713 pass 6714 else: 6715 if isinstance(obj, ModelDescriptor): 6716 return obj.__get__(None, self) 6717 6718 model_attr = getattr(self.model, attr) 6719 if isinstance(model_attr, Field): 6720 self.__dict__[attr] = FieldAlias.create(self, model_attr) 6721 return self.__dict__[attr] 6722 return model_attr 6723 6724 def __setattr__(self, attr, value): 6725 raise AttributeError('Cannot set attributes on model aliases.') 6726 6727 def get_field_aliases(self): 6728 return [getattr(self, n) for n in self.model._meta.sorted_field_names] 6729 6730 def select(self, *selection): 6731 if not selection: 6732 selection = self.get_field_aliases() 6733 return ModelSelect(self, selection) 6734 6735 def __call__(self, **kwargs): 6736 return self.model(**kwargs) 6737 6738 def __sql__(self, ctx): 6739 if ctx.scope == SCOPE_VALUES: 6740 # Return the quoted table name. 6741 return ctx.sql(self.model) 6742 6743 if self.alias: 6744 ctx.alias_manager[self] = self.alias 6745 6746 if ctx.scope == SCOPE_SOURCE: 6747 # Define the table and its alias. 6748 return (ctx 6749 .sql(self.model._meta.entity) 6750 .literal(' AS ') 6751 .sql(Entity(ctx.alias_manager[self]))) 6752 else: 6753 # Refer to the table using the alias. 6754 return ctx.sql(Entity(ctx.alias_manager[self])) 6755 6756 6757class FieldAlias(Field): 6758 def __init__(self, source, field): 6759 self.source = source 6760 self.model = source.model 6761 self.field = field 6762 6763 @classmethod 6764 def create(cls, source, field): 6765 class _FieldAlias(cls, type(field)): 6766 pass 6767 return _FieldAlias(source, field) 6768 6769 def clone(self): 6770 return FieldAlias(self.source, self.field) 6771 6772 def adapt(self, value): return self.field.adapt(value) 6773 def python_value(self, value): return self.field.python_value(value) 6774 def db_value(self, value): return self.field.db_value(value) 6775 def __getattr__(self, attr): 6776 return self.source if attr == 'model' else getattr(self.field, attr) 6777 6778 def __sql__(self, ctx): 6779 return ctx.sql(Column(self.source, self.field.column_name)) 6780 6781 6782def sort_models(models): 6783 models = set(models) 6784 seen = set() 6785 ordering = [] 6786 def dfs(model): 6787 if model in models and model not in seen: 6788 seen.add(model) 6789 for foreign_key, rel_model in model._meta.refs.items(): 6790 # Do not depth-first search deferred foreign-keys as this can 6791 # cause tables to be created in the incorrect order. 6792 if not foreign_key.deferred: 6793 dfs(rel_model) 6794 if model._meta.depends_on: 6795 for dependency in model._meta.depends_on: 6796 dfs(dependency) 6797 ordering.append(model) 6798 6799 names = lambda m: (m._meta.name, m._meta.table_name) 6800 for m in sorted(models, key=names): 6801 dfs(m) 6802 return ordering 6803 6804 6805class _ModelQueryHelper(object): 6806 default_row_type = ROW.MODEL 6807 6808 def __init__(self, *args, **kwargs): 6809 super(_ModelQueryHelper, self).__init__(*args, **kwargs) 6810 if not self._database: 6811 self._database = self.model._meta.database 6812 6813 @Node.copy 6814 def objects(self, constructor=None): 6815 self._row_type = ROW.CONSTRUCTOR 6816 self._constructor = self.model if constructor is None else constructor 6817 6818 def _get_cursor_wrapper(self, cursor): 6819 row_type = self._row_type or self.default_row_type 6820 if row_type == ROW.MODEL: 6821 return self._get_model_cursor_wrapper(cursor) 6822 elif row_type == ROW.DICT: 6823 return ModelDictCursorWrapper(cursor, self.model, self._returning) 6824 elif row_type == ROW.TUPLE: 6825 return ModelTupleCursorWrapper(cursor, self.model, self._returning) 6826 elif row_type == ROW.NAMED_TUPLE: 6827 return ModelNamedTupleCursorWrapper(cursor, self.model, 6828 self._returning) 6829 elif row_type == ROW.CONSTRUCTOR: 6830 return ModelObjectCursorWrapper(cursor, self.model, 6831 self._returning, self._constructor) 6832 else: 6833 raise ValueError('Unrecognized row type: "%s".' % row_type) 6834 6835 def _get_model_cursor_wrapper(self, cursor): 6836 return ModelObjectCursorWrapper(cursor, self.model, [], self.model) 6837 6838 6839class ModelRaw(_ModelQueryHelper, RawQuery): 6840 def __init__(self, model, sql, params, **kwargs): 6841 self.model = model 6842 self._returning = () 6843 super(ModelRaw, self).__init__(sql=sql, params=params, **kwargs) 6844 6845 def get(self): 6846 try: 6847 return self.execute()[0] 6848 except IndexError: 6849 sql, params = self.sql() 6850 raise self.model.DoesNotExist('%s instance matching query does ' 6851 'not exist:\nSQL: %s\nParams: %s' % 6852 (self.model, sql, params)) 6853 6854 6855class BaseModelSelect(_ModelQueryHelper): 6856 def union_all(self, rhs): 6857 return ModelCompoundSelectQuery(self.model, self, 'UNION ALL', rhs) 6858 __add__ = union_all 6859 6860 def union(self, rhs): 6861 return ModelCompoundSelectQuery(self.model, self, 'UNION', rhs) 6862 __or__ = union 6863 6864 def intersect(self, rhs): 6865 return ModelCompoundSelectQuery(self.model, self, 'INTERSECT', rhs) 6866 __and__ = intersect 6867 6868 def except_(self, rhs): 6869 return ModelCompoundSelectQuery(self.model, self, 'EXCEPT', rhs) 6870 __sub__ = except_ 6871 6872 def __iter__(self): 6873 if not self._cursor_wrapper: 6874 self.execute() 6875 return iter(self._cursor_wrapper) 6876 6877 def prefetch(self, *subqueries): 6878 return prefetch(self, *subqueries) 6879 6880 def get(self, database=None): 6881 clone = self.paginate(1, 1) 6882 clone._cursor_wrapper = None 6883 try: 6884 return clone.execute(database)[0] 6885 except IndexError: 6886 sql, params = clone.sql() 6887 raise self.model.DoesNotExist('%s instance matching query does ' 6888 'not exist:\nSQL: %s\nParams: %s' % 6889 (clone.model, sql, params)) 6890 6891 @Node.copy 6892 def group_by(self, *columns): 6893 grouping = [] 6894 for column in columns: 6895 if is_model(column): 6896 grouping.extend(column._meta.sorted_fields) 6897 elif isinstance(column, Table): 6898 if not column._columns: 6899 raise ValueError('Cannot pass a table to group_by() that ' 6900 'does not have columns explicitly ' 6901 'declared.') 6902 grouping.extend([getattr(column, col_name) 6903 for col_name in column._columns]) 6904 else: 6905 grouping.append(column) 6906 self._group_by = grouping 6907 6908 6909class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery): 6910 def __init__(self, model, *args, **kwargs): 6911 self.model = model 6912 super(ModelCompoundSelectQuery, self).__init__(*args, **kwargs) 6913 6914 def _get_model_cursor_wrapper(self, cursor): 6915 return self.lhs._get_model_cursor_wrapper(cursor) 6916 6917 6918def _normalize_model_select(fields_or_models): 6919 fields = [] 6920 for fm in fields_or_models: 6921 if is_model(fm): 6922 fields.extend(fm._meta.sorted_fields) 6923 elif isinstance(fm, ModelAlias): 6924 fields.extend(fm.get_field_aliases()) 6925 elif isinstance(fm, Table) and fm._columns: 6926 fields.extend([getattr(fm, col) for col in fm._columns]) 6927 else: 6928 fields.append(fm) 6929 return fields 6930 6931 6932class ModelSelect(BaseModelSelect, Select): 6933 def __init__(self, model, fields_or_models, is_default=False): 6934 self.model = self._join_ctx = model 6935 self._joins = {} 6936 self._is_default = is_default 6937 fields = _normalize_model_select(fields_or_models) 6938 super(ModelSelect, self).__init__([model], fields) 6939 6940 def clone(self): 6941 clone = super(ModelSelect, self).clone() 6942 if clone._joins: 6943 clone._joins = dict(clone._joins) 6944 return clone 6945 6946 def select(self, *fields_or_models): 6947 if fields_or_models or not self._is_default: 6948 self._is_default = False 6949 fields = _normalize_model_select(fields_or_models) 6950 return super(ModelSelect, self).select(*fields) 6951 return self 6952 6953 def switch(self, ctx=None): 6954 self._join_ctx = self.model if ctx is None else ctx 6955 return self 6956 6957 def _get_model(self, src): 6958 if is_model(src): 6959 return src, True 6960 elif isinstance(src, Table) and src._model: 6961 return src._model, False 6962 elif isinstance(src, ModelAlias): 6963 return src.model, False 6964 elif isinstance(src, ModelSelect): 6965 return src.model, False 6966 return None, False 6967 6968 def _normalize_join(self, src, dest, on, attr): 6969 # Allow "on" expression to have an alias that determines the 6970 # destination attribute for the joined data. 6971 on_alias = isinstance(on, Alias) 6972 if on_alias: 6973 attr = attr or on._alias 6974 on = on.alias() 6975 6976 # Obtain references to the source and destination models being joined. 6977 src_model, src_is_model = self._get_model(src) 6978 dest_model, dest_is_model = self._get_model(dest) 6979 6980 if src_model and dest_model: 6981 self._join_ctx = dest 6982 constructor = dest_model 6983 6984 # In the case where the "on" clause is a Column or Field, we will 6985 # convert that field into the appropriate predicate expression. 6986 if not (src_is_model and dest_is_model) and isinstance(on, Column): 6987 if on.source is src: 6988 to_field = src_model._meta.columns[on.name] 6989 elif on.source is dest: 6990 to_field = dest_model._meta.columns[on.name] 6991 else: 6992 raise AttributeError('"on" clause Column %s does not ' 6993 'belong to %s or %s.' % 6994 (on, src_model, dest_model)) 6995 on = None 6996 elif isinstance(on, Field): 6997 to_field = on 6998 on = None 6999 else: 7000 to_field = None 7001 7002 fk_field, is_backref = self._generate_on_clause( 7003 src_model, dest_model, to_field, on) 7004 7005 if on is None: 7006 src_attr = 'name' if src_is_model else 'column_name' 7007 dest_attr = 'name' if dest_is_model else 'column_name' 7008 if is_backref: 7009 lhs = getattr(dest, getattr(fk_field, dest_attr)) 7010 rhs = getattr(src, getattr(fk_field.rel_field, src_attr)) 7011 else: 7012 lhs = getattr(src, getattr(fk_field, src_attr)) 7013 rhs = getattr(dest, getattr(fk_field.rel_field, dest_attr)) 7014 on = (lhs == rhs) 7015 7016 if not attr: 7017 if fk_field is not None and not is_backref: 7018 attr = fk_field.name 7019 else: 7020 attr = dest_model._meta.name 7021 elif on_alias and fk_field is not None and \ 7022 attr == fk_field.object_id_name and not is_backref: 7023 raise ValueError('Cannot assign join alias to "%s", as this ' 7024 'attribute is the object_id_name for the ' 7025 'foreign-key field "%s"' % (attr, fk_field)) 7026 7027 elif isinstance(dest, Source): 7028 constructor = dict 7029 attr = attr or dest._alias 7030 if not attr and isinstance(dest, Table): 7031 attr = attr or dest.__name__ 7032 7033 return (on, attr, constructor) 7034 7035 def _generate_on_clause(self, src, dest, to_field=None, on=None): 7036 meta = src._meta 7037 is_backref = fk_fields = False 7038 7039 # Get all the foreign keys between source and dest, and determine if 7040 # the join is via a back-reference. 7041 if dest in meta.model_refs: 7042 fk_fields = meta.model_refs[dest] 7043 elif dest in meta.model_backrefs: 7044 fk_fields = meta.model_backrefs[dest] 7045 is_backref = True 7046 7047 if not fk_fields: 7048 if on is not None: 7049 return None, False 7050 raise ValueError('Unable to find foreign key between %s and %s. ' 7051 'Please specify an explicit join condition.' % 7052 (src, dest)) 7053 elif to_field is not None: 7054 # If the foreign-key field was specified explicitly, remove all 7055 # other foreign-key fields from the list. 7056 target = (to_field.field if isinstance(to_field, FieldAlias) 7057 else to_field) 7058 fk_fields = [f for f in fk_fields if ( 7059 (f is target) or 7060 (is_backref and f.rel_field is to_field))] 7061 7062 if len(fk_fields) == 1: 7063 return fk_fields[0], is_backref 7064 7065 if on is None: 7066 # If multiple foreign-keys exist, try using the FK whose name 7067 # matches that of the related model. If not, raise an error as this 7068 # is ambiguous. 7069 for fk in fk_fields: 7070 if fk.name == dest._meta.name: 7071 return fk, is_backref 7072 7073 raise ValueError('More than one foreign key between %s and %s.' 7074 ' Please specify which you are joining on.' % 7075 (src, dest)) 7076 7077 # If there are multiple foreign-keys to choose from and the join 7078 # predicate is an expression, we'll try to figure out which 7079 # foreign-key field we're joining on so that we can assign to the 7080 # correct attribute when resolving the model graph. 7081 to_field = None 7082 if isinstance(on, Expression): 7083 lhs, rhs = on.lhs, on.rhs 7084 # Coerce to set() so that we force Python to compare using the 7085 # object's hash rather than equality test, which returns a 7086 # false-positive due to overriding __eq__. 7087 fk_set = set(fk_fields) 7088 7089 if isinstance(lhs, Field): 7090 lhs_f = lhs.field if isinstance(lhs, FieldAlias) else lhs 7091 if lhs_f in fk_set: 7092 to_field = lhs_f 7093 elif isinstance(rhs, Field): 7094 rhs_f = rhs.field if isinstance(rhs, FieldAlias) else rhs 7095 if rhs_f in fk_set: 7096 to_field = rhs_f 7097 7098 return to_field, False 7099 7100 @Node.copy 7101 def join(self, dest, join_type=JOIN.INNER, on=None, src=None, attr=None): 7102 src = self._join_ctx if src is None else src 7103 7104 if join_type == JOIN.LATERAL or join_type == JOIN.LEFT_LATERAL: 7105 on = True 7106 elif join_type != JOIN.CROSS: 7107 on, attr, constructor = self._normalize_join(src, dest, on, attr) 7108 if attr: 7109 self._joins.setdefault(src, []) 7110 self._joins[src].append((dest, attr, constructor, join_type)) 7111 elif on is not None: 7112 raise ValueError('Cannot specify on clause with cross join.') 7113 7114 if not self._from_list: 7115 raise ValueError('No sources to join on.') 7116 7117 item = self._from_list.pop() 7118 self._from_list.append(Join(item, dest, join_type, on)) 7119 7120 def join_from(self, src, dest, join_type=JOIN.INNER, on=None, attr=None): 7121 return self.join(dest, join_type, on, src, attr) 7122 7123 def _get_model_cursor_wrapper(self, cursor): 7124 if len(self._from_list) == 1 and not self._joins: 7125 return ModelObjectCursorWrapper(cursor, self.model, 7126 self._returning, self.model) 7127 return ModelCursorWrapper(cursor, self.model, self._returning, 7128 self._from_list, self._joins) 7129 7130 def ensure_join(self, lm, rm, on=None, **join_kwargs): 7131 join_ctx = self._join_ctx 7132 for dest, _, constructor, _ in self._joins.get(lm, []): 7133 if dest == rm: 7134 return self 7135 return self.switch(lm).join(rm, on=on, **join_kwargs).switch(join_ctx) 7136 7137 def convert_dict_to_node(self, qdict): 7138 accum = [] 7139 joins = [] 7140 fks = (ForeignKeyField, BackrefAccessor) 7141 for key, value in sorted(qdict.items()): 7142 curr = self.model 7143 if '__' in key and key.rsplit('__', 1)[1] in DJANGO_MAP: 7144 key, op = key.rsplit('__', 1) 7145 op = DJANGO_MAP[op] 7146 elif value is None: 7147 op = DJANGO_MAP['is'] 7148 else: 7149 op = DJANGO_MAP['eq'] 7150 7151 if '__' not in key: 7152 # Handle simplest case. This avoids joining over-eagerly when a 7153 # direct FK lookup is all that is required. 7154 model_attr = getattr(curr, key) 7155 else: 7156 for piece in key.split('__'): 7157 for dest, attr, _, _ in self._joins.get(curr, ()): 7158 if attr == piece or (isinstance(dest, ModelAlias) and 7159 dest.alias == piece): 7160 curr = dest 7161 break 7162 else: 7163 model_attr = getattr(curr, piece) 7164 if value is not None and isinstance(model_attr, fks): 7165 curr = model_attr.rel_model 7166 joins.append(model_attr) 7167 accum.append(op(model_attr, value)) 7168 return accum, joins 7169 7170 def filter(self, *args, **kwargs): 7171 # normalize args and kwargs into a new expression 7172 if args and kwargs: 7173 dq_node = (reduce(operator.and_, [a.clone() for a in args]) & 7174 DQ(**kwargs)) 7175 elif args: 7176 dq_node = (reduce(operator.and_, [a.clone() for a in args]) & 7177 ColumnBase()) 7178 elif kwargs: 7179 dq_node = DQ(**kwargs) & ColumnBase() 7180 else: 7181 return self.clone() 7182 7183 # dq_node should now be an Expression, lhs = Node(), rhs = ... 7184 q = collections.deque([dq_node]) 7185 dq_joins = [] 7186 seen_joins = set() 7187 while q: 7188 curr = q.popleft() 7189 if not isinstance(curr, Expression): 7190 continue 7191 for side, piece in (('lhs', curr.lhs), ('rhs', curr.rhs)): 7192 if isinstance(piece, DQ): 7193 query, joins = self.convert_dict_to_node(piece.query) 7194 for join in joins: 7195 if join not in seen_joins: 7196 dq_joins.append(join) 7197 seen_joins.add(join) 7198 expression = reduce(operator.and_, query) 7199 # Apply values from the DQ object. 7200 if piece._negated: 7201 expression = Negated(expression) 7202 #expression._alias = piece._alias 7203 setattr(curr, side, expression) 7204 else: 7205 q.append(piece) 7206 7207 if not args or not kwargs: 7208 dq_node = dq_node.lhs 7209 7210 query = self.clone() 7211 for field in dq_joins: 7212 if isinstance(field, ForeignKeyField): 7213 lm, rm = field.model, field.rel_model 7214 field_obj = field 7215 elif isinstance(field, BackrefAccessor): 7216 lm, rm = field.model, field.rel_model 7217 field_obj = field.field 7218 query = query.ensure_join(lm, rm, field_obj) 7219 return query.where(dq_node) 7220 7221 def create_table(self, name, safe=True, **meta): 7222 return self.model._schema.create_table_as(name, self, safe, **meta) 7223 7224 def __sql_selection__(self, ctx, is_subquery=False): 7225 if self._is_default and is_subquery and len(self._returning) > 1 and \ 7226 self.model._meta.primary_key is not False: 7227 return ctx.sql(self.model._meta.primary_key) 7228 7229 return ctx.sql(CommaNodeList(self._returning)) 7230 7231 7232class NoopModelSelect(ModelSelect): 7233 def __sql__(self, ctx): 7234 return self.model._meta.database.get_noop_select(ctx) 7235 7236 def _get_cursor_wrapper(self, cursor): 7237 return CursorWrapper(cursor) 7238 7239 7240class _ModelWriteQueryHelper(_ModelQueryHelper): 7241 def __init__(self, model, *args, **kwargs): 7242 self.model = model 7243 super(_ModelWriteQueryHelper, self).__init__(model, *args, **kwargs) 7244 7245 def returning(self, *returning): 7246 accum = [] 7247 for item in returning: 7248 if is_model(item): 7249 accum.extend(item._meta.sorted_fields) 7250 else: 7251 accum.append(item) 7252 return super(_ModelWriteQueryHelper, self).returning(*accum) 7253 7254 def _set_table_alias(self, ctx): 7255 table = self.model._meta.table 7256 ctx.alias_manager[table] = table.__name__ 7257 7258 7259class ModelUpdate(_ModelWriteQueryHelper, Update): 7260 pass 7261 7262 7263class ModelInsert(_ModelWriteQueryHelper, Insert): 7264 default_row_type = ROW.TUPLE 7265 7266 def __init__(self, *args, **kwargs): 7267 super(ModelInsert, self).__init__(*args, **kwargs) 7268 if self._returning is None and self.model._meta.database is not None: 7269 if self.model._meta.database.returning_clause: 7270 self._returning = self.model._meta.get_primary_keys() 7271 7272 def returning(self, *returning): 7273 # By default ModelInsert will yield a `tuple` containing the 7274 # primary-key of the newly inserted row. But if we are explicitly 7275 # specifying a returning clause and have not set a row type, we will 7276 # default to returning model instances instead. 7277 if returning and self._row_type is None: 7278 self._row_type = ROW.MODEL 7279 return super(ModelInsert, self).returning(*returning) 7280 7281 def get_default_data(self): 7282 return self.model._meta.defaults 7283 7284 def get_default_columns(self): 7285 fields = self.model._meta.sorted_fields 7286 return fields[1:] if self.model._meta.auto_increment else fields 7287 7288 7289class ModelDelete(_ModelWriteQueryHelper, Delete): 7290 pass 7291 7292 7293class ManyToManyQuery(ModelSelect): 7294 def __init__(self, instance, accessor, rel, *args, **kwargs): 7295 self._instance = instance 7296 self._accessor = accessor 7297 self._src_attr = accessor.src_fk.rel_field.name 7298 self._dest_attr = accessor.dest_fk.rel_field.name 7299 super(ManyToManyQuery, self).__init__(rel, (rel,), *args, **kwargs) 7300 7301 def _id_list(self, model_or_id_list): 7302 if isinstance(model_or_id_list[0], Model): 7303 return [getattr(obj, self._dest_attr) for obj in model_or_id_list] 7304 return model_or_id_list 7305 7306 def add(self, value, clear_existing=False): 7307 if clear_existing: 7308 self.clear() 7309 7310 accessor = self._accessor 7311 src_id = getattr(self._instance, self._src_attr) 7312 if isinstance(value, SelectQuery): 7313 query = value.columns( 7314 Value(src_id), 7315 accessor.dest_fk.rel_field) 7316 accessor.through_model.insert_from( 7317 fields=[accessor.src_fk, accessor.dest_fk], 7318 query=query).execute() 7319 else: 7320 value = ensure_tuple(value) 7321 if not value: return 7322 7323 inserts = [{ 7324 accessor.src_fk.name: src_id, 7325 accessor.dest_fk.name: rel_id} 7326 for rel_id in self._id_list(value)] 7327 accessor.through_model.insert_many(inserts).execute() 7328 7329 def remove(self, value): 7330 src_id = getattr(self._instance, self._src_attr) 7331 if isinstance(value, SelectQuery): 7332 column = getattr(value.model, self._dest_attr) 7333 subquery = value.columns(column) 7334 return (self._accessor.through_model 7335 .delete() 7336 .where( 7337 (self._accessor.dest_fk << subquery) & 7338 (self._accessor.src_fk == src_id)) 7339 .execute()) 7340 else: 7341 value = ensure_tuple(value) 7342 if not value: 7343 return 7344 return (self._accessor.through_model 7345 .delete() 7346 .where( 7347 (self._accessor.dest_fk << self._id_list(value)) & 7348 (self._accessor.src_fk == src_id)) 7349 .execute()) 7350 7351 def clear(self): 7352 src_id = getattr(self._instance, self._src_attr) 7353 return (self._accessor.through_model 7354 .delete() 7355 .where(self._accessor.src_fk == src_id) 7356 .execute()) 7357 7358 7359def safe_python_value(conv_func): 7360 def validate(value): 7361 try: 7362 return conv_func(value) 7363 except (TypeError, ValueError): 7364 return value 7365 return validate 7366 7367 7368class BaseModelCursorWrapper(DictCursorWrapper): 7369 def __init__(self, cursor, model, columns): 7370 super(BaseModelCursorWrapper, self).__init__(cursor) 7371 self.model = model 7372 self.select = columns or [] 7373 7374 def _initialize_columns(self): 7375 combined = self.model._meta.combined 7376 table = self.model._meta.table 7377 description = self.cursor.description 7378 7379 self.ncols = len(self.cursor.description) 7380 self.columns = [] 7381 self.converters = converters = [None] * self.ncols 7382 self.fields = fields = [None] * self.ncols 7383 7384 for idx, description_item in enumerate(description): 7385 column = description_item[0] 7386 dot_index = column.find('.') 7387 if dot_index != -1: 7388 column = column[dot_index + 1:] 7389 7390 column = column.strip('")') 7391 self.columns.append(column) 7392 try: 7393 raw_node = self.select[idx] 7394 except IndexError: 7395 if column in combined: 7396 raw_node = node = combined[column] 7397 else: 7398 continue 7399 else: 7400 node = raw_node.unwrap() 7401 7402 # Heuristics used to attempt to get the field associated with a 7403 # given SELECT column, so that we can accurately convert the value 7404 # returned by the database-cursor into a Python object. 7405 if isinstance(node, Field): 7406 if raw_node._coerce: 7407 converters[idx] = node.python_value 7408 fields[idx] = node 7409 if not raw_node.is_alias(): 7410 self.columns[idx] = node.name 7411 elif isinstance(node, ColumnBase) and raw_node._converter: 7412 converters[idx] = raw_node._converter 7413 elif isinstance(node, Function) and node._coerce: 7414 if node._python_value is not None: 7415 converters[idx] = node._python_value 7416 elif node.arguments and isinstance(node.arguments[0], Node): 7417 # If the first argument is a field or references a column 7418 # on a Model, try using that field's conversion function. 7419 # This usually works, but we use "safe_python_value()" so 7420 # that if a TypeError or ValueError occurs during 7421 # conversion we can just fall-back to the raw cursor value. 7422 first = node.arguments[0].unwrap() 7423 if isinstance(first, Entity): 7424 path = first._path[-1] # Try to look-up by name. 7425 first = combined.get(path) 7426 if isinstance(first, Field): 7427 converters[idx] = safe_python_value(first.python_value) 7428 elif column in combined: 7429 if node._coerce: 7430 converters[idx] = combined[column].python_value 7431 if isinstance(node, Column) and node.source == table: 7432 fields[idx] = combined[column] 7433 7434 initialize = _initialize_columns 7435 7436 def process_row(self, row): 7437 raise NotImplementedError 7438 7439 7440class ModelDictCursorWrapper(BaseModelCursorWrapper): 7441 def process_row(self, row): 7442 result = {} 7443 columns, converters = self.columns, self.converters 7444 fields = self.fields 7445 7446 for i in range(self.ncols): 7447 attr = columns[i] 7448 if attr in result: continue # Don't overwrite if we have dupes. 7449 if converters[i] is not None: 7450 result[attr] = converters[i](row[i]) 7451 else: 7452 result[attr] = row[i] 7453 7454 return result 7455 7456 7457class ModelTupleCursorWrapper(ModelDictCursorWrapper): 7458 constructor = tuple 7459 7460 def process_row(self, row): 7461 columns, converters = self.columns, self.converters 7462 return self.constructor([ 7463 (converters[i](row[i]) if converters[i] is not None else row[i]) 7464 for i in range(self.ncols)]) 7465 7466 7467class ModelNamedTupleCursorWrapper(ModelTupleCursorWrapper): 7468 def initialize(self): 7469 self._initialize_columns() 7470 attributes = [] 7471 for i in range(self.ncols): 7472 attributes.append(self.columns[i]) 7473 self.tuple_class = collections.namedtuple('Row', attributes) 7474 self.constructor = lambda row: self.tuple_class(*row) 7475 7476 7477class ModelObjectCursorWrapper(ModelDictCursorWrapper): 7478 def __init__(self, cursor, model, select, constructor): 7479 self.constructor = constructor 7480 self.is_model = is_model(constructor) 7481 super(ModelObjectCursorWrapper, self).__init__(cursor, model, select) 7482 7483 def process_row(self, row): 7484 data = super(ModelObjectCursorWrapper, self).process_row(row) 7485 if self.is_model: 7486 # Clear out any dirty fields before returning to the user. 7487 obj = self.constructor(__no_default__=1, **data) 7488 obj._dirty.clear() 7489 return obj 7490 else: 7491 return self.constructor(**data) 7492 7493 7494class ModelCursorWrapper(BaseModelCursorWrapper): 7495 def __init__(self, cursor, model, select, from_list, joins): 7496 super(ModelCursorWrapper, self).__init__(cursor, model, select) 7497 self.from_list = from_list 7498 self.joins = joins 7499 7500 def initialize(self): 7501 self._initialize_columns() 7502 selected_src = set([field.model for field in self.fields 7503 if field is not None]) 7504 select, columns = self.select, self.columns 7505 7506 self.key_to_constructor = {self.model: self.model} 7507 self.src_is_dest = {} 7508 self.src_to_dest = [] 7509 accum = collections.deque(self.from_list) 7510 dests = set() 7511 7512 while accum: 7513 curr = accum.popleft() 7514 if isinstance(curr, Join): 7515 accum.append(curr.lhs) 7516 accum.append(curr.rhs) 7517 continue 7518 7519 if curr not in self.joins: 7520 continue 7521 7522 is_dict = isinstance(curr, dict) 7523 for key, attr, constructor, join_type in self.joins[curr]: 7524 if key not in self.key_to_constructor: 7525 self.key_to_constructor[key] = constructor 7526 7527 # (src, attr, dest, is_dict, join_type). 7528 self.src_to_dest.append((curr, attr, key, is_dict, 7529 join_type)) 7530 dests.add(key) 7531 accum.append(key) 7532 7533 # Ensure that we accommodate everything selected. 7534 for src in selected_src: 7535 if src not in self.key_to_constructor: 7536 if is_model(src): 7537 self.key_to_constructor[src] = src 7538 elif isinstance(src, ModelAlias): 7539 self.key_to_constructor[src] = src.model 7540 7541 # Indicate which sources are also dests. 7542 for src, _, dest, _, _ in self.src_to_dest: 7543 self.src_is_dest[src] = src in dests and (dest in selected_src 7544 or src in selected_src) 7545 7546 self.column_keys = [] 7547 for idx, node in enumerate(select): 7548 key = self.model 7549 field = self.fields[idx] 7550 if field is not None: 7551 if isinstance(field, FieldAlias): 7552 key = field.source 7553 else: 7554 key = field.model 7555 else: 7556 if isinstance(node, Node): 7557 node = node.unwrap() 7558 if isinstance(node, Column): 7559 key = node.source 7560 7561 self.column_keys.append(key) 7562 7563 def process_row(self, row): 7564 objects = {} 7565 object_list = [] 7566 for key, constructor in self.key_to_constructor.items(): 7567 objects[key] = constructor(__no_default__=True) 7568 object_list.append(objects[key]) 7569 7570 default_instance = objects[self.model] 7571 7572 set_keys = set() 7573 for idx, key in enumerate(self.column_keys): 7574 # Get the instance corresponding to the selected column/value, 7575 # falling back to the "root" model instance. 7576 instance = objects.get(key, default_instance) 7577 column = self.columns[idx] 7578 value = row[idx] 7579 if value is not None: 7580 set_keys.add(key) 7581 if self.converters[idx]: 7582 value = self.converters[idx](value) 7583 7584 if isinstance(instance, dict): 7585 instance[column] = value 7586 else: 7587 setattr(instance, column, value) 7588 7589 # Need to do some analysis on the joins before this. 7590 for (src, attr, dest, is_dict, join_type) in self.src_to_dest: 7591 instance = objects[src] 7592 try: 7593 joined_instance = objects[dest] 7594 except KeyError: 7595 continue 7596 7597 # If no fields were set on the destination instance then do not 7598 # assign an "empty" instance. 7599 if instance is None or dest is None or \ 7600 (dest not in set_keys and not self.src_is_dest.get(dest)): 7601 continue 7602 7603 # If no fields were set on either the source or the destination, 7604 # then we have nothing to do here. 7605 if instance not in set_keys and dest not in set_keys \ 7606 and join_type.endswith('OUTER JOIN'): 7607 continue 7608 7609 if is_dict: 7610 instance[attr] = joined_instance 7611 else: 7612 setattr(instance, attr, joined_instance) 7613 7614 # When instantiating models from a cursor, we clear the dirty fields. 7615 for instance in object_list: 7616 if isinstance(instance, Model): 7617 instance._dirty.clear() 7618 7619 return objects[self.model] 7620 7621 7622class PrefetchQuery(collections.namedtuple('_PrefetchQuery', ( 7623 'query', 'fields', 'is_backref', 'rel_models', 'field_to_name', 'model'))): 7624 def __new__(cls, query, fields=None, is_backref=None, rel_models=None, 7625 field_to_name=None, model=None): 7626 if fields: 7627 if is_backref: 7628 if rel_models is None: 7629 rel_models = [field.model for field in fields] 7630 foreign_key_attrs = [field.rel_field.name for field in fields] 7631 else: 7632 if rel_models is None: 7633 rel_models = [field.rel_model for field in fields] 7634 foreign_key_attrs = [field.name for field in fields] 7635 field_to_name = list(zip(fields, foreign_key_attrs)) 7636 model = query.model 7637 return super(PrefetchQuery, cls).__new__( 7638 cls, query, fields, is_backref, rel_models, field_to_name, model) 7639 7640 def populate_instance(self, instance, id_map): 7641 if self.is_backref: 7642 for field in self.fields: 7643 identifier = instance.__data__[field.name] 7644 key = (field, identifier) 7645 if key in id_map: 7646 setattr(instance, field.name, id_map[key]) 7647 else: 7648 for field, attname in self.field_to_name: 7649 identifier = instance.__data__[field.rel_field.name] 7650 key = (field, identifier) 7651 rel_instances = id_map.get(key, []) 7652 for inst in rel_instances: 7653 setattr(inst, attname, instance) 7654 inst._dirty.clear() 7655 setattr(instance, field.backref, rel_instances) 7656 7657 def store_instance(self, instance, id_map): 7658 for field, attname in self.field_to_name: 7659 identity = field.rel_field.python_value(instance.__data__[attname]) 7660 key = (field, identity) 7661 if self.is_backref: 7662 id_map[key] = instance 7663 else: 7664 id_map.setdefault(key, []) 7665 id_map[key].append(instance) 7666 7667 7668def prefetch_add_subquery(sq, subqueries): 7669 fixed_queries = [PrefetchQuery(sq)] 7670 for i, subquery in enumerate(subqueries): 7671 if isinstance(subquery, tuple): 7672 subquery, target_model = subquery 7673 else: 7674 target_model = None 7675 if not isinstance(subquery, Query) and is_model(subquery) or \ 7676 isinstance(subquery, ModelAlias): 7677 subquery = subquery.select() 7678 subquery_model = subquery.model 7679 fks = backrefs = None 7680 for j in reversed(range(i + 1)): 7681 fixed = fixed_queries[j] 7682 last_query = fixed.query 7683 last_model = last_obj = fixed.model 7684 if isinstance(last_model, ModelAlias): 7685 last_model = last_model.model 7686 rels = subquery_model._meta.model_refs.get(last_model, []) 7687 if rels: 7688 fks = [getattr(subquery_model, fk.name) for fk in rels] 7689 pks = [getattr(last_obj, fk.rel_field.name) for fk in rels] 7690 else: 7691 backrefs = subquery_model._meta.model_backrefs.get(last_model) 7692 if (fks or backrefs) and ((target_model is last_obj) or 7693 (target_model is None)): 7694 break 7695 7696 if not fks and not backrefs: 7697 tgt_err = ' using %s' % target_model if target_model else '' 7698 raise AttributeError('Error: unable to find foreign key for ' 7699 'query: %s%s' % (subquery, tgt_err)) 7700 7701 dest = (target_model,) if target_model else None 7702 7703 if fks: 7704 expr = reduce(operator.or_, [ 7705 (fk << last_query.select(pk)) 7706 for (fk, pk) in zip(fks, pks)]) 7707 subquery = subquery.where(expr) 7708 fixed_queries.append(PrefetchQuery(subquery, fks, False, dest)) 7709 elif backrefs: 7710 expressions = [] 7711 for backref in backrefs: 7712 rel_field = getattr(subquery_model, backref.rel_field.name) 7713 fk_field = getattr(last_obj, backref.name) 7714 expressions.append(rel_field << last_query.select(fk_field)) 7715 subquery = subquery.where(reduce(operator.or_, expressions)) 7716 fixed_queries.append(PrefetchQuery(subquery, backrefs, True, dest)) 7717 7718 return fixed_queries 7719 7720 7721def prefetch(sq, *subqueries): 7722 if not subqueries: 7723 return sq 7724 7725 fixed_queries = prefetch_add_subquery(sq, subqueries) 7726 deps = {} 7727 rel_map = {} 7728 for pq in reversed(fixed_queries): 7729 query_model = pq.model 7730 if pq.fields: 7731 for rel_model in pq.rel_models: 7732 rel_map.setdefault(rel_model, []) 7733 rel_map[rel_model].append(pq) 7734 7735 deps.setdefault(query_model, {}) 7736 id_map = deps[query_model] 7737 has_relations = bool(rel_map.get(query_model)) 7738 7739 for instance in pq.query: 7740 if pq.fields: 7741 pq.store_instance(instance, id_map) 7742 if has_relations: 7743 for rel in rel_map[query_model]: 7744 rel.populate_instance(instance, deps[rel.model]) 7745 7746 return list(pq.query) 7747