import json
import math
import re
import struct
import sys

from peewee import *
from peewee import ColumnBase
from peewee import EnclosedNodeList
from peewee import Entity
from peewee import Expression
from peewee import Node
from peewee import NodeList
from peewee import OP
from peewee import VirtualField
from peewee import merge_dict
from peewee import sqlite3
try:
    from playhouse._sqlite_ext import (
        backup,
        backup_to_file,
        Blob,
        ConnectionHelper,
        register_bloomfilter,
        register_hash_functions,
        register_rank_functions,
        sqlite_get_db_status,
        sqlite_get_status,
        TableFunction,
        ZeroBlob,
    )
    CYTHON_SQLITE_EXTENSIONS = True
except ImportError:
    CYTHON_SQLITE_EXTENSIONS = False


if sys.version_info[0] == 3:
    basestring = str


FTS3_MATCHINFO = 'pcx'
FTS4_MATCHINFO = 'pcnalx'
if sqlite3 is not None:
    FTS_VERSION = 4 if sqlite3.sqlite_version_info[:3] >= (3, 7, 4) else 3
else:
    FTS_VERSION = 3

FTS5_MIN_SQLITE_VERSION = (3, 9, 0)


class RowIDField(AutoField):
    auto_increment = True
    column_name = name = required_name = 'rowid'

    def bind(self, model, name, *args):
        if name != self.required_name:
            raise ValueError('%s must be named "%s".' %
                             (type(self), self.required_name))
        super(RowIDField, self).bind(model, name, *args)


class DocIDField(RowIDField):
    column_name = name = required_name = 'docid'


class AutoIncrementField(AutoField):
    def ddl(self, ctx):
        node_list = super(AutoIncrementField, self).ddl(ctx)
        return NodeList((node_list, SQL('AUTOINCREMENT')))


class TDecimalField(DecimalField):
    field_type = 'TEXT'
    def get_modifiers(self): pass


class JSONPath(ColumnBase):
    def __init__(self, field, path=None):
        super(JSONPath, self).__init__()
        self._field = field
        self._path = path or ()

    @property
    def path(self):
        return Value('$%s' % ''.join(self._path))

    def __getitem__(self, idx):
        if isinstance(idx, int):
            item = '[%s]' % idx
        else:
            item = '.%s' % idx
        return JSONPath(self._field, self._path + (item,))

    def set(self, value, as_json=None):
        if as_json or isinstance(value, (list, dict)):
            value = fn.json(self._field._json_dumps(value))
        return fn.json_set(self._field, self.path, value)

    def update(self, value):
        return self.set(fn.json_patch(self, self._field._json_dumps(value)))

    def remove(self):
        return fn.json_remove(self._field, self.path)

    def json_type(self):
        return fn.json_type(self._field, self.path)

    def length(self):
        return fn.json_array_length(self._field, self.path)

    def children(self):
        return fn.json_each(self._field, self.path)

    def tree(self):
        return fn.json_tree(self._field, self.path)

    def __sql__(self, ctx):
        return ctx.sql(fn.json_extract(self._field, self.path)
                       if self._path else self._field)


class JSONField(TextField):
    field_type = 'JSON'
    unpack = False

    def __init__(self, json_dumps=None, json_loads=None, **kwargs):
        self._json_dumps = json_dumps or json.dumps
        self._json_loads = json_loads or json.loads
        super(JSONField, self).__init__(**kwargs)

    def python_value(self, value):
        if value is not None:
            try:
                return self._json_loads(value)
            except (TypeError, ValueError):
                return value

    def db_value(self, value):
        if value is not None:
            if not isinstance(value, Node):
                value = fn.json(self._json_dumps(value))
            return value

    def _e(op):
        def inner(self, rhs):
            if isinstance(rhs, (list, dict)):
                rhs = Value(rhs, converter=self.db_value, unpack=False)
            return Expression(self, op, rhs)
        return inner
    __eq__ = _e(OP.EQ)
    __ne__ = _e(OP.NE)
    __gt__ = _e(OP.GT)
    __ge__ = _e(OP.GTE)
    __lt__ = _e(OP.LT)
    __le__ = _e(OP.LTE)
    __hash__ = Field.__hash__

    def __getitem__(self, item):
        return JSONPath(self)[item]

    def set(self, value, as_json=None):
        return JSONPath(self).set(value, as_json)

    def update(self, data):
        return JSONPath(self).update(data)

    def remove(self):
        return JSONPath(self).remove()

    def json_type(self):
        return fn.json_type(self)

    def length(self):
        return fn.json_array_length(self)

    def children(self):
        """
        Schema of `json_each` and `json_tree`:

        key,
        value,
        type TEXT (object, array, string, etc),
        atom (value for primitive/scalar types, NULL for array and object)
        id INTEGER (unique identifier for element)
        parent INTEGER (unique identifier of parent element or NULL)
        fullkey TEXT (full path describing element)
        path TEXT (path to the container of the current element)
        json JSON hidden (1st input parameter to function)
        root TEXT hidden (2nd input parameter, path at which to start)
        """
        return fn.json_each(self)

    def tree(self):
        return fn.json_tree(self)
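

# A usage sketch for JSONField and JSONPath (not part of the original module;
# the "KV" model and in-memory database below are hypothetical):
#
#   db = SqliteExtDatabase(':memory:')
#
#   class KV(Model):
#       key = TextField()
#       value = JSONField()
#
#       class Meta:
#           database = db
#
#   db.create_tables([KV])
#   KV.create(key='a', value={'items': [1, 2, 3]})
#
#   # json_extract() a nested value: the path is built by indexing the field.
#   first = KV.select(KV.value['items'][0].alias('first')).scalar()
#
#   # json_set() a nested value in place.
#   KV.update(value=KV.value['items'][2].set(99)).execute()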


class SearchField(Field):
    def __init__(self, unindexed=False, column_name=None, **k):
        if k:
            raise ValueError('SearchField does not accept these keyword '
                             'arguments: %s.' % sorted(k))
        super(SearchField, self).__init__(unindexed=unindexed,
                                          column_name=column_name, null=True)

    def match(self, term):
        return match(self, term)


class VirtualTableSchemaManager(SchemaManager):
    def _create_virtual_table(self, safe=True, **options):
        options = self.model.clean_options(
            merge_dict(self.model._meta.options, options))

        # Structure:
        # CREATE VIRTUAL TABLE <model>
        # USING <extension_module>
        # ([prefix_arguments, ...] fields, ... [arguments, ...], [options...])
        ctx = self._create_context()
        ctx.literal('CREATE VIRTUAL TABLE ')
        if safe:
            ctx.literal('IF NOT EXISTS ')
        (ctx
         .sql(self.model)
         .literal(' USING '))

        ext_module = self.model._meta.extension_module
        if isinstance(ext_module, Node):
            return ctx.sql(ext_module)

        ctx.sql(SQL(ext_module)).literal(' ')
        arguments = []
        meta = self.model._meta

        if meta.prefix_arguments:
            arguments.extend([SQL(a) for a in meta.prefix_arguments])

        # Constraints, data-types, foreign and primary keys are all omitted.
        for field in meta.sorted_fields:
            if isinstance(field, (RowIDField)) or field._hidden:
                continue
            field_def = [Entity(field.column_name)]
            if field.unindexed:
                field_def.append(SQL('UNINDEXED'))
            arguments.append(NodeList(field_def))

        if meta.arguments:
            arguments.extend([SQL(a) for a in meta.arguments])

        if options:
            arguments.extend(self._create_table_option_sql(options))
        return ctx.sql(EnclosedNodeList(arguments))

    def _create_table(self, safe=True, **options):
        if issubclass(self.model, VirtualModel):
            return self._create_virtual_table(safe, **options)

        return super(VirtualTableSchemaManager, self)._create_table(
            safe, **options)


class VirtualModel(Model):
    class Meta:
        arguments = None
        extension_module = None
        prefix_arguments = None
        primary_key = False
        schema_manager_class = VirtualTableSchemaManager

    @classmethod
    def clean_options(cls, options):
        return options


class BaseFTSModel(VirtualModel):
    @classmethod
    def clean_options(cls, options):
        content = options.get('content')
        prefix = options.get('prefix')
        tokenize = options.get('tokenize')

        if isinstance(content, basestring) and content == '':
            # Special-case content-less full-text search tables.
            options['content'] = "''"
        elif isinstance(content, Field):
            # Special-case to ensure fields are fully-qualified.
            options['content'] = Entity(content.model._meta.table_name,
                                        content.column_name)

        if prefix:
            if isinstance(prefix, (list, tuple)):
                prefix = ','.join([str(i) for i in prefix])
            options['prefix'] = "'%s'" % prefix.strip("' ")

        if tokenize and cls._meta.extension_module.lower() == 'fts5':
            # Tokenizers need to be in quoted string for FTS5, but not for
            # FTS3 or FTS4.
            options['tokenize'] = '"%s"' % tokenize

        return options


class FTSModel(BaseFTSModel):
    """
    VirtualModel class for creating tables that use either the FTS3 or FTS4
    search extensions. Peewee automatically determines which version of the
    FTS extension is supported and will use FTS4 if possible.
    """
    # FTS3/4 uses "docid" in the same way a normal table uses "rowid".
    docid = DocIDField()

    class Meta:
        extension_module = 'FTS%s' % FTS_VERSION

    @classmethod
    def _fts_cmd(cls, cmd):
        tbl = cls._meta.table_name
        res = cls._meta.database.execute_sql(
            "INSERT INTO %s(%s) VALUES('%s');" % (tbl, tbl, cmd))
        return res.fetchone()

    @classmethod
    def optimize(cls):
        return cls._fts_cmd('optimize')

    @classmethod
    def rebuild(cls):
        return cls._fts_cmd('rebuild')

    @classmethod
    def integrity_check(cls):
        return cls._fts_cmd('integrity-check')

    @classmethod
    def merge(cls, blocks=200, segments=8):
        return cls._fts_cmd('merge=%s,%s' % (blocks, segments))

    @classmethod
    def automerge(cls, state=True):
        return cls._fts_cmd('automerge=%s' % (state and '1' or '0'))

    @classmethod
    def match(cls, term):
        """
        Generate a `MATCH` expression appropriate for searching this table.
        """
        return match(cls._meta.entity, term)

    @classmethod
    def rank(cls, *weights):
        matchinfo = fn.matchinfo(cls._meta.entity, FTS3_MATCHINFO)
        return fn.fts_rank(matchinfo, *weights)

    @classmethod
    def bm25(cls, *weights):
        match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO)
        return fn.fts_bm25(match_info, *weights)

    @classmethod
    def bm25f(cls, *weights):
        match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO)
        return fn.fts_bm25f(match_info, *weights)

    @classmethod
    def lucene(cls, *weights):
        match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO)
        return fn.fts_lucene(match_info, *weights)

    @classmethod
    def _search(cls, term, weights, with_score, score_alias, score_fn,
                explicit_ordering):
        if not weights:
            rank = score_fn()
        elif isinstance(weights, dict):
            weight_args = []
            for field in cls._meta.sorted_fields:
                # Attempt to get the specified weight of the field by looking
                # it up using its field instance followed by its name.
                field_weight = weights.get(field, weights.get(field.name, 1.0))
                weight_args.append(field_weight)
            rank = score_fn(*weight_args)
        else:
            rank = score_fn(*weights)

        selection = ()
        order_by = rank
        if with_score:
            selection = (cls, rank.alias(score_alias))
        if with_score and not explicit_ordering:
            order_by = SQL(score_alias)

        return (cls
                .select(*selection)
                .where(cls.match(term))
                .order_by(order_by))

    @classmethod
    def search(cls, term, weights=None, with_score=False, score_alias='score',
               explicit_ordering=False):
        """Full-text search using selected `term`."""
        return cls._search(
            term,
            weights,
            with_score,
            score_alias,
            cls.rank,
            explicit_ordering)

    @classmethod
    def search_bm25(cls, term, weights=None, with_score=False,
                    score_alias='score', explicit_ordering=False):
        """Full-text search for selected `term` using the BM25 algorithm."""
        return cls._search(
            term,
            weights,
            with_score,
            score_alias,
            cls.bm25,
            explicit_ordering)

    @classmethod
    def search_bm25f(cls, term, weights=None, with_score=False,
                     score_alias='score', explicit_ordering=False):
        """Full-text search for selected `term` using the BM25f algorithm."""
        return cls._search(
            term,
            weights,
            with_score,
            score_alias,
            cls.bm25f,
            explicit_ordering)

    @classmethod
    def search_lucene(cls, term, weights=None, with_score=False,
                      score_alias='score', explicit_ordering=False):
        """Full-text search for selected `term` using the Lucene algorithm."""
        return cls._search(
            term,
            weights,
            with_score,
            score_alias,
            cls.lucene,
            explicit_ordering)
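

# A usage sketch for FTSModel (not part of the original module; the
# "DocumentIndex" model and its database binding are hypothetical):
#
#   class DocumentIndex(FTSModel):
#       title = SearchField()
#       content = SearchField()
#
#       class Meta:
#           database = db
#
#   DocumentIndex.create_table()
#   DocumentIndex.insert({
#       DocumentIndex.title: 'peewee',
#       DocumentIndex.content: 'a small orm with sqlite extensions'}).execute()
#
#   # Rank matches with BM25; weights may be passed per field name.
#   query = DocumentIndex.search_bm25('sqlite', weights={'title': 2.0},
#                                     with_score=True)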


_alphabet = 'abcdefghijklmnopqrstuvwxyz'
_alphanum = (set('\t ,"(){}*:_+0123456789') |
             set(_alphabet) |
             set(_alphabet.upper()) |
             set((chr(26),)))
_invalid_ascii = set(chr(p) for p in range(128) if chr(p) not in _alphanum)
_quote_re = re.compile(r'(?:[^\s"]|"(?:\\.|[^"])*")+')


class FTS5Model(BaseFTSModel):
    """
    Requires SQLite >= 3.9.0.

    Table options:

    content: table name of external content, or empty string for "contentless"
    content_rowid: column name of external content primary key
    prefix: integer(s). Ex: '2' or '2 3 4'
    tokenize: porter, unicode61, ascii. Ex: 'porter unicode61'

    The unicode tokenizer supports the following parameters:

    * remove_diacritics (1 or 0, default is 1)
    * tokenchars (string of characters, e.g. '-_')
    * separators (string of characters)

    Parameters are passed as alternating parameter name and value, so:

    {'tokenize': "unicode61 remove_diacritics 0 tokenchars '-_'"}

    Content-less tables:

    If you do not need the full-text content in its original form, you can
    specify a content-less table. Searches and auxiliary functions will work
    as usual, but the only value that can be returned when SELECT-ing is the
    rowid. Content-less tables also do not support UPDATE or DELETE.

    External content tables:

    You can set up triggers to sync these, e.g.

    -- Create a table and an external content fts5 table to index it.
    CREATE TABLE tbl(a INTEGER PRIMARY KEY, b);
    CREATE VIRTUAL TABLE ft USING fts5(b, content='tbl', content_rowid='a');

    -- Triggers to keep the FTS index up to date.
    CREATE TRIGGER tbl_ai AFTER INSERT ON tbl BEGIN
      INSERT INTO ft(rowid, b) VALUES (new.a, new.b);
    END;
    CREATE TRIGGER tbl_ad AFTER DELETE ON tbl BEGIN
      INSERT INTO ft(ft, rowid, b) VALUES('delete', old.a, old.b);
    END;
    CREATE TRIGGER tbl_au AFTER UPDATE ON tbl BEGIN
      INSERT INTO ft(ft, rowid, b) VALUES('delete', old.a, old.b);
      INSERT INTO ft(rowid, b) VALUES (new.a, new.b);
    END;

    Built-in auxiliary functions:

    * bm25(tbl[, weight_0, ... weight_n])
    * highlight(tbl, col_idx, prefix, suffix)
    * snippet(tbl, col_idx, prefix, suffix, ellipsis, max_tokens)
    """
    # FTS5 does not support declared primary keys, but we can use the
    # implicit rowid.
    rowid = RowIDField()

    class Meta:
        extension_module = 'fts5'

    _error_messages = {
        'field_type': ('Besides the implicit `rowid` column, all columns must '
                       'be instances of SearchField'),
        'index': 'Secondary indexes are not supported for FTS5 models',
        'pk': 'FTS5 models must use the default `rowid` primary key',
    }

    @classmethod
    def validate_model(cls):
        # Perform FTS5-specific validation and options post-processing.
        if cls._meta.primary_key.name != 'rowid':
            raise ImproperlyConfigured(cls._error_messages['pk'])
        for field in cls._meta.fields.values():
            if not isinstance(field, (SearchField, RowIDField)):
                raise ImproperlyConfigured(cls._error_messages['field_type'])
        if cls._meta.indexes:
            raise ImproperlyConfigured(cls._error_messages['index'])

    @classmethod
    def fts5_installed(cls):
        if sqlite3.sqlite_version_info[:3] < FTS5_MIN_SQLITE_VERSION:
            return False

        # Test in-memory DB to determine if the FTS5 extension is installed.
        tmp_db = sqlite3.connect(':memory:')
        try:
            tmp_db.execute('CREATE VIRTUAL TABLE fts5test USING fts5 (data);')
        except:
            try:
                tmp_db.enable_load_extension(True)
                tmp_db.load_extension('fts5')
            except:
                return False
            else:
                cls._meta.database.load_extension('fts5')
        finally:
            tmp_db.close()

        return True

    @staticmethod
    def validate_query(query):
        """
        Simple helper function to indicate whether a search query is a
        valid FTS5 query. Note: this simply looks at the characters being
        used, and is not guaranteed to catch all problematic queries.
        """
        tokens = _quote_re.findall(query)
        for token in tokens:
            if token.startswith('"') and token.endswith('"'):
                continue
            if set(token) & _invalid_ascii:
                return False
        return True

    @staticmethod
    def clean_query(query, replace=chr(26)):
        """
        Clean a query of invalid tokens.
        """
        accum = []
        any_invalid = False
        tokens = _quote_re.findall(query)
        for token in tokens:
            if token.startswith('"') and token.endswith('"'):
                accum.append(token)
                continue
            token_set = set(token)
            invalid_for_token = token_set & _invalid_ascii
            if invalid_for_token:
                any_invalid = True
                for c in invalid_for_token:
                    token = token.replace(c, replace)
            accum.append(token)

        if any_invalid:
            return ' '.join(accum)
        return query

    @classmethod
    def match(cls, term):
        """
        Generate a `MATCH` expression appropriate for searching this table.
        """
        return match(cls._meta.entity, term)

    @classmethod
    def rank(cls, *args):
        return cls.bm25(*args) if args else SQL('rank')

    @classmethod
    def bm25(cls, *weights):
        return fn.bm25(cls._meta.entity, *weights)

    @classmethod
    def search(cls, term, weights=None, with_score=False, score_alias='score',
               explicit_ordering=False):
        """Full-text search using selected `term`."""
        return cls.search_bm25(
            FTS5Model.clean_query(term),
            weights,
            with_score,
            score_alias,
            explicit_ordering)

    @classmethod
    def search_bm25(cls, term, weights=None, with_score=False,
                    score_alias='score', explicit_ordering=False):
        """Full-text search using selected `term`."""
        if not weights:
            rank = SQL('rank')
        elif isinstance(weights, dict):
            weight_args = []
            for field in cls._meta.sorted_fields:
                if isinstance(field, SearchField) and not field.unindexed:
                    weight_args.append(
                        weights.get(field, weights.get(field.name, 1.0)))
            rank = fn.bm25(cls._meta.entity, *weight_args)
        else:
            rank = fn.bm25(cls._meta.entity, *weights)

        selection = ()
        order_by = rank
        if with_score:
            selection = (cls, rank.alias(score_alias))
        if with_score and not explicit_ordering:
            order_by = SQL(score_alias)

        return (cls
                .select(*selection)
                .where(cls.match(FTS5Model.clean_query(term)))
                .order_by(order_by))

    @classmethod
    def _fts_cmd_sql(cls, cmd, **extra_params):
        tbl = cls._meta.entity
        columns = [tbl]
        values = [cmd]
        for key, value in extra_params.items():
            columns.append(Entity(key))
            values.append(value)

        return NodeList((
            SQL('INSERT INTO'),
            cls._meta.entity,
            EnclosedNodeList(columns),
            SQL('VALUES'),
            EnclosedNodeList(values)))

    @classmethod
    def _fts_cmd(cls, cmd, **extra_params):
        query = cls._fts_cmd_sql(cmd, **extra_params)
        return cls._meta.database.execute(query)

    @classmethod
    def automerge(cls, level):
        if not (0 <= level <= 16):
            raise ValueError('level must be between 0 and 16')
        return cls._fts_cmd('automerge', rank=level)

    @classmethod
    def merge(cls, npages):
        return cls._fts_cmd('merge', rank=npages)

    @classmethod
    def set_pgsz(cls, pgsz):
        return cls._fts_cmd('pgsz', rank=pgsz)

    @classmethod
    def set_rank(cls, rank_expression):
        return cls._fts_cmd('rank', rank=rank_expression)

    @classmethod
    def delete_all(cls):
        return cls._fts_cmd('delete-all')

    @classmethod
    def VocabModel(cls, table_type='row', table=None):
        if table_type not in ('row', 'col', 'instance'):
            raise ValueError('table_type must be either "row", "col" or '
                             '"instance".')

        attr = '_vocab_model_%s' % table_type

        if not hasattr(cls, attr):
            class Meta:
                database = cls._meta.database
                table_name = table or cls._meta.table_name + '_v'
                extension_module = fn.fts5vocab(
                    cls._meta.entity,
                    SQL(table_type))

            attrs = {
                'term': VirtualField(TextField),
                'doc': IntegerField(),
                'cnt': IntegerField(),
                'rowid': RowIDField(),
                'Meta': Meta,
            }
            if table_type == 'col':
                attrs['col'] = VirtualField(TextField)
            elif table_type == 'instance':
                attrs['offset'] = VirtualField(IntegerField)

            class_name = '%sVocab' % cls.__name__
            setattr(cls, attr, type(class_name, (VirtualModel,), attrs))

        return getattr(cls, attr)
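

# A usage sketch for FTS5Model (not part of the original module; the
# "NoteIndex" model, its database binding and the tokenizer choice are
# hypothetical):
#
#   class NoteIndex(FTS5Model):
#       title = SearchField()
#       body = SearchField()
#
#       class Meta:
#           database = db
#           options = {'tokenize': 'porter'}
#
#   NoteIndex.create_table()
#   NoteIndex.insert({NoteIndex.title: 'fts5',
#                     NoteIndex.body: 'hello world'}).execute()
#
#   # Results are ordered best-match first using the built-in bm25() ranking.
#   query = NoteIndex.search('hello', with_score=True)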


def ClosureTable(model_class, foreign_key=None, referencing_class=None,
                 referencing_key=None):
    """Model factory for the transitive closure extension."""
    if referencing_class is None:
        referencing_class = model_class

    if foreign_key is None:
        for field_obj in model_class._meta.refs:
            if field_obj.rel_model is model_class:
                foreign_key = field_obj
                break
        else:
            raise ValueError('Unable to find self-referential foreign key.')

    source_key = model_class._meta.primary_key
    if referencing_key is None:
        referencing_key = source_key

    class BaseClosureTable(VirtualModel):
        depth = VirtualField(IntegerField)
        id = VirtualField(IntegerField)
        idcolumn = VirtualField(TextField)
        parentcolumn = VirtualField(TextField)
        root = VirtualField(IntegerField)
        tablename = VirtualField(TextField)

        class Meta:
            extension_module = 'transitive_closure'

        @classmethod
        def descendants(cls, node, depth=None, include_node=False):
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(source_key == cls.id))
                     .where(cls.root == node)
                     .objects())
            if depth is not None:
                query = query.where(cls.depth == depth)
            elif not include_node:
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def ancestors(cls, node, depth=None, include_node=False):
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(source_key == cls.root))
                     .where(cls.id == node)
                     .objects())
            if depth:
                query = query.where(cls.depth == depth)
            elif not include_node:
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def siblings(cls, node, include_node=False):
            if referencing_class is model_class:
                # self-join
                fk_value = node.__data__.get(foreign_key.name)
                query = model_class.select().where(foreign_key == fk_value)
            else:
                # siblings as given in referencing_class
                siblings = (referencing_class
                            .select(referencing_key)
                            .join(cls, on=(foreign_key == cls.root))
                            .where((cls.id == node) & (cls.depth == 1)))

                # the according models
                query = (model_class
                         .select()
                         .where(source_key << siblings)
                         .objects())

            if not include_node:
                query = query.where(source_key != node)

            return query

    class Meta:
        database = referencing_class._meta.database
        options = {
            'tablename': referencing_class._meta.table_name,
            'idcolumn': referencing_key.column_name,
            'parentcolumn': foreign_key.column_name}
        primary_key = False

    name = '%sClosure' % model_class.__name__
    return type(name, (BaseClosureTable,), {'Meta': Meta})
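

# A usage sketch for ClosureTable (not part of the original module; requires
# the "transitive_closure" loadable extension, and the "Category" model and
# extension path below are hypothetical):
#
#   class Category(Model):
#       name = TextField()
#       parent = ForeignKeyField('self', backref='children', null=True)
#
#       class Meta:
#           database = db
#
#   CategoryClosure = ClosureTable(Category)
#
#   db.load_extension('closure')  # Path to the compiled closure extension.
#   db.create_tables([Category, CategoryClosure])
#
#   root = Category.get(Category.name == 'root')
#   subtree = CategoryClosure.descendants(root, include_node=False)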


class LSMTable(VirtualModel):
    class Meta:
        extension_module = 'lsm1'
        filename = None

    @classmethod
    def clean_options(cls, options):
        filename = cls._meta.filename
        if not filename:
            raise ValueError('LSM1 extension requires that you specify a '
                             'filename for the LSM database.')
        else:
            if len(filename) >= 2 and filename[0] != '"':
                filename = '"%s"' % filename
        if not cls._meta.primary_key:
            raise ValueError('LSM1 models must specify a primary-key field.')

        key = cls._meta.primary_key
        if isinstance(key, AutoField):
            raise ValueError('LSM1 models must explicitly declare a primary '
                             'key field.')
        if not isinstance(key, (TextField, BlobField, IntegerField)):
            raise ValueError('LSM1 key must be a TextField, BlobField, or '
                             'IntegerField.')
        key._hidden = True
        if isinstance(key, IntegerField):
            data_type = 'UINT'
        elif isinstance(key, BlobField):
            data_type = 'BLOB'
        else:
            data_type = 'TEXT'
        cls._meta.prefix_arguments = [filename, '"%s"' % key.name, data_type]

        # Does the key map to a scalar value, or a tuple of values?
        if len(cls._meta.sorted_fields) == 2:
            cls._meta._value_field = cls._meta.sorted_fields[1]
        else:
            cls._meta._value_field = None

        return options

    @classmethod
    def load_extension(cls, path='lsm.so'):
        cls._meta.database.load_extension(path)

    @staticmethod
    def slice_to_expr(key, idx):
        if idx.start is not None and idx.stop is not None:
            return key.between(idx.start, idx.stop)
        elif idx.start is not None:
            return key >= idx.start
        elif idx.stop is not None:
            return key <= idx.stop

    @staticmethod
    def _apply_lookup_to_query(query, key, lookup):
        if isinstance(lookup, slice):
            expr = LSMTable.slice_to_expr(key, lookup)
            if expr is not None:
                query = query.where(expr)
            return query, False
        elif isinstance(lookup, Expression):
            return query.where(lookup), False
        else:
            return query.where(key == lookup), True

    @classmethod
    def get_by_id(cls, pk):
        query, is_single = cls._apply_lookup_to_query(
            cls.select().namedtuples(),
            cls._meta.primary_key,
            pk)

        if is_single:
            try:
                row = query.get()
            except cls.DoesNotExist:
                raise KeyError(pk)
            return row[1] if cls._meta._value_field is not None else row
        else:
            return query

    @classmethod
    def set_by_id(cls, key, value):
        if cls._meta._value_field is not None:
            data = {cls._meta._value_field: value}
        elif isinstance(value, tuple):
            data = {}
            for field, fval in zip(cls._meta.sorted_fields[1:], value):
                data[field] = fval
        elif isinstance(value, dict):
            data = value
        elif isinstance(value, cls):
            data = value.__dict__
        data[cls._meta.primary_key] = key
        cls.replace(data).execute()

    @classmethod
    def delete_by_id(cls, pk):
        query, is_single = cls._apply_lookup_to_query(
            cls.delete(),
            cls._meta.primary_key,
            pk)
        return query.execute()


OP.MATCH = 'MATCH'

def _sqlite_regexp(regex, value):
    return re.search(regex, value) is not None


class SqliteExtDatabase(SqliteDatabase):
    def __init__(self, database, c_extensions=None, rank_functions=True,
                 hash_functions=False, regexp_function=False,
                 bloomfilter=False, json_contains=False, *args, **kwargs):
        super(SqliteExtDatabase, self).__init__(database, *args, **kwargs)
        self._row_factory = None

        if c_extensions and not CYTHON_SQLITE_EXTENSIONS:
            raise ImproperlyConfigured('SqliteExtDatabase initialized with '
                                       'C extensions, but shared library was '
                                       'not found!')
        prefer_c = CYTHON_SQLITE_EXTENSIONS and (c_extensions is not False)
        if rank_functions:
            if prefer_c:
                register_rank_functions(self)
            else:
                self.register_function(bm25, 'fts_bm25')
                self.register_function(rank, 'fts_rank')
                self.register_function(bm25, 'fts_bm25f')  # Fall back to bm25.
                self.register_function(bm25, 'fts_lucene')
        if hash_functions:
            if not prefer_c:
                raise ValueError('C extension required to register hash '
                                 'functions.')
            register_hash_functions(self)
        if regexp_function:
            self.register_function(_sqlite_regexp, 'regexp', 2)
        if bloomfilter:
            if not prefer_c:
                raise ValueError('C extension required to use bloomfilter.')
            register_bloomfilter(self)
        if json_contains:
            self.register_function(_json_contains, 'json_contains')

        self._c_extensions = prefer_c

    def _add_conn_hooks(self, conn):
        super(SqliteExtDatabase, self)._add_conn_hooks(conn)
        if self._row_factory:
            conn.row_factory = self._row_factory

    def row_factory(self, fn):
        self._row_factory = fn
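

# A configuration sketch for SqliteExtDatabase (not part of the original
# module; the filename and the row factory shown are hypothetical):
#
#   db = SqliteExtDatabase('app.db', regexp_function=True, json_contains=True)
#
#   # With regexp_function=True the REGEXP operator is available, e.g.:
#   #   Person.select().where(Person.name.regexp('^Hu'))
#
#   # Rows can be mapped through a custom sqlite3 row factory.
#   def dict_row(cursor, row):
#       return {col[0]: val for col, val in zip(cursor.description, row)}
#   db.row_factory(dict_row)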


if CYTHON_SQLITE_EXTENSIONS:
    SQLITE_STATUS_MEMORY_USED = 0
    SQLITE_STATUS_PAGECACHE_USED = 1
    SQLITE_STATUS_PAGECACHE_OVERFLOW = 2
    SQLITE_STATUS_SCRATCH_USED = 3
    SQLITE_STATUS_SCRATCH_OVERFLOW = 4
    SQLITE_STATUS_MALLOC_SIZE = 5
    SQLITE_STATUS_PARSER_STACK = 6
    SQLITE_STATUS_PAGECACHE_SIZE = 7
    SQLITE_STATUS_SCRATCH_SIZE = 8
    SQLITE_STATUS_MALLOC_COUNT = 9
    SQLITE_DBSTATUS_LOOKASIDE_USED = 0
    SQLITE_DBSTATUS_CACHE_USED = 1
    SQLITE_DBSTATUS_SCHEMA_USED = 2
    SQLITE_DBSTATUS_STMT_USED = 3
    SQLITE_DBSTATUS_LOOKASIDE_HIT = 4
    SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE = 5
    SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL = 6
    SQLITE_DBSTATUS_CACHE_HIT = 7
    SQLITE_DBSTATUS_CACHE_MISS = 8
    SQLITE_DBSTATUS_CACHE_WRITE = 9
    SQLITE_DBSTATUS_DEFERRED_FKS = 10
    #SQLITE_DBSTATUS_CACHE_USED_SHARED = 11

    def __status__(flag, return_highwater=False):
        """
        Expose a sqlite3_status() call for a particular flag as a property of
        the Database object.
        """
        def getter(self):
            result = sqlite_get_status(flag)
            return result[1] if return_highwater else result
        return property(getter)

    def __dbstatus__(flag, return_highwater=False, return_current=False):
        """
        Expose a sqlite3_dbstatus() call for a particular flag as a property of
        the Database instance. Unlike sqlite3_status(), the dbstatus properties
        pertain to the current connection.
        """
        def getter(self):
            if self._state.conn is None:
                raise ImproperlyConfigured('database connection not opened.')
            result = sqlite_get_db_status(self._state.conn, flag)
            if return_current:
                return result[0]
            return result[1] if return_highwater else result
        return property(getter)

    class CSqliteExtDatabase(SqliteExtDatabase):
        def __init__(self, *args, **kwargs):
            self._conn_helper = None
            self._commit_hook = self._rollback_hook = self._update_hook = None
            self._replace_busy_handler = False
            super(CSqliteExtDatabase, self).__init__(*args, **kwargs)

        def init(self, database, replace_busy_handler=False, **kwargs):
            super(CSqliteExtDatabase, self).init(database, **kwargs)
            self._replace_busy_handler = replace_busy_handler

        def _close(self, conn):
            if self._commit_hook:
                self._conn_helper.set_commit_hook(None)
            if self._rollback_hook:
                self._conn_helper.set_rollback_hook(None)
            if self._update_hook:
                self._conn_helper.set_update_hook(None)
            return super(CSqliteExtDatabase, self)._close(conn)

        def _add_conn_hooks(self, conn):
            super(CSqliteExtDatabase, self)._add_conn_hooks(conn)
            self._conn_helper = ConnectionHelper(conn)
            if self._commit_hook is not None:
                self._conn_helper.set_commit_hook(self._commit_hook)
            if self._rollback_hook is not None:
                self._conn_helper.set_rollback_hook(self._rollback_hook)
            if self._update_hook is not None:
                self._conn_helper.set_update_hook(self._update_hook)
            if self._replace_busy_handler:
                timeout = self._timeout or 5
                self._conn_helper.set_busy_handler(timeout * 1000)

        def on_commit(self, fn):
            self._commit_hook = fn
            if not self.is_closed():
                self._conn_helper.set_commit_hook(fn)
            return fn

        def on_rollback(self, fn):
            self._rollback_hook = fn
            if not self.is_closed():
                self._conn_helper.set_rollback_hook(fn)
            return fn

        def on_update(self, fn):
            self._update_hook = fn
            if not self.is_closed():
                self._conn_helper.set_update_hook(fn)
            return fn

        def changes(self):
            return self._conn_helper.changes()

        @property
        def last_insert_rowid(self):
            return self._conn_helper.last_insert_rowid()

        @property
        def autocommit(self):
            return self._conn_helper.autocommit()

        def backup(self, destination, pages=None, name=None, progress=None):
            return backup(self.connection(), destination.connection(),
                          pages=pages, name=name, progress=progress)

        def backup_to_file(self, filename, pages=None, name=None,
                           progress=None):
            return backup_to_file(self.connection(), filename, pages=pages,
                                  name=name, progress=progress)

        def blob_open(self, table, column, rowid, read_only=False):
            return Blob(self, table, column, rowid, read_only)

        # Status properties.
        memory_used = __status__(SQLITE_STATUS_MEMORY_USED)
        malloc_size = __status__(SQLITE_STATUS_MALLOC_SIZE, True)
        malloc_count = __status__(SQLITE_STATUS_MALLOC_COUNT)
        pagecache_used = __status__(SQLITE_STATUS_PAGECACHE_USED)
        pagecache_overflow = __status__(SQLITE_STATUS_PAGECACHE_OVERFLOW)
        pagecache_size = __status__(SQLITE_STATUS_PAGECACHE_SIZE, True)
        scratch_used = __status__(SQLITE_STATUS_SCRATCH_USED)
        scratch_overflow = __status__(SQLITE_STATUS_SCRATCH_OVERFLOW)
        scratch_size = __status__(SQLITE_STATUS_SCRATCH_SIZE, True)

        # Connection status properties.
        lookaside_used = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_USED)
        lookaside_hit = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_HIT, True)
        lookaside_miss = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE,
                                      True)
        lookaside_miss_full = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL,
                                           True)
        cache_used = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED, False, True)
        #cache_used_shared = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED_SHARED,
        #                                 False, True)
        schema_used = __dbstatus__(SQLITE_DBSTATUS_SCHEMA_USED, False, True)
        statement_used = __dbstatus__(SQLITE_DBSTATUS_STMT_USED, False, True)
        cache_hit = __dbstatus__(SQLITE_DBSTATUS_CACHE_HIT, False, True)
        cache_miss = __dbstatus__(SQLITE_DBSTATUS_CACHE_MISS, False, True)
        cache_write = __dbstatus__(SQLITE_DBSTATUS_CACHE_WRITE, False, True)
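

# A usage sketch for the hooks and status properties above (not part of the
# original module; requires the compiled C extension, and the database
# filename is hypothetical):
#
#   db = CSqliteExtDatabase('app.db')
#
#   @db.on_commit
#   def log_commit():
#       print('commit issued')
#
#   @db.on_update
#   def log_update(query_type, db_name, table, rowid):
#       print(query_type, table, rowid)
#
#   db.connect()
#   print(db.memory_used)  # sqlite3_status() value, global.
#   print(db.cache_used)   # sqlite3_dbstatus() value, per-connection.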


def match(lhs, rhs):
    return Expression(lhs, OP.MATCH, rhs)

def _parse_match_info(buf):
    # See http://sqlite.org/fts3.html#matchinfo
    bufsize = len(buf)  # Length in bytes.
    return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)]

def get_weights(ncol, raw_weights):
    if not raw_weights:
        return [1] * ncol
    else:
        weights = [0] * ncol
        for i, weight in enumerate(raw_weights):
            weights[i] = weight
        return weights

# Ranking implementation, which parses matchinfo.
def rank(raw_match_info, *raw_weights):
    # Handle match_info called w/default args 'pcx' - based on the example rank
    # function http://sqlite.org/fts3.html#appendix_a
    match_info = _parse_match_info(raw_match_info)
    score = 0.0

    p, c = match_info[:2]
    weights = get_weights(c, raw_weights)

    # matchinfo X value corresponds to, for each phrase in the search query, a
    # list of 3 values for each column in the search table.
    # So if we have a two-phrase search query and three columns of data, the
    # following would be the layout:
    # p0 : c0=[0, 1, 2],   c1=[3, 4, 5],    c2=[6, 7, 8]
    # p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17]
    for phrase_num in range(p):
        phrase_info_idx = 2 + (phrase_num * c * 3)
        for col_num in range(c):
            weight = weights[col_num]
            if not weight:
                continue

            col_idx = phrase_info_idx + (col_num * 3)

            # The idea is that we count the number of times the phrase appears
            # in this column of the current row, compared to how many times it
            # appears in this column across all rows. The ratio of these values
            # provides a rough way to score based on "high value" terms.
            row_hits = match_info[col_idx]
            all_rows_hits = match_info[col_idx + 1]
            if row_hits > 0:
                score += weight * (float(row_hits) / all_rows_hits)

    return -score

# Okapi BM25 ranking implementation (FTS4 only).
def bm25(raw_match_info, *args):
    """
    Usage:

        # Format string *must* be pcnalx
        # Second parameter to bm25 specifies the index of the column, on
        # the table being queried.
        bm25(matchinfo(document_tbl, 'pcnalx'), 1) AS rank
    """
    match_info = _parse_match_info(raw_match_info)
    K = 1.2
    B = 0.75
    score = 0.0

    P_O, C_O, N_O, A_O = range(4)  # Offsets into the matchinfo buffer.
    term_count = match_info[P_O]  # n
    col_count = match_info[C_O]
    total_docs = match_info[N_O]  # N
    L_O = A_O + col_count
    X_O = L_O + col_count

    # Worked example of pcnalx for two columns and two phrases, 100 docs total.
    # {
    #   p = 2
    #   c = 2
    #   n = 100
    #   a0 = 4   -- avg number of tokens for col0, e.g. title
    #   a1 = 40  -- avg number of tokens for col1, e.g. body
    #   l0 = 5   -- curr doc has 5 tokens in col0
    #   l1 = 30  -- curr doc has 30 tokens in col1
    #
    #   x000  -- hits this row for phrase0, col0
    #   x001  -- hits all rows for phrase0, col0
    #   x002  -- rows with phrase0 in col0 at least once
    #
    #   x010  -- hits this row for phrase0, col1
    #   x011  -- hits all rows for phrase0, col1
    #   x012  -- rows with phrase0 in col1 at least once
    #
    #   x100  -- hits this row for phrase1, col0
    #   x101  -- hits all rows for phrase1, col0
    #   x102  -- rows with phrase1 in col0 at least once
    #
    #   x110  -- hits this row for phrase1, col1
    #   x111  -- hits all rows for phrase1, col1
    #   x112  -- rows with phrase1 in col1 at least once
    # }

    weights = get_weights(col_count, args)

    for i in range(term_count):
        for j in range(col_count):
            weight = weights[j]
            if weight == 0:
                continue

            x = X_O + (3 * (j + i * col_count))
            term_frequency = float(match_info[x])  # f(qi, D)
            docs_with_term = float(match_info[x + 2])  # n(qi)

            # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) )
            idf = math.log(
                (total_docs - docs_with_term + 0.5) /
                (docs_with_term + 0.5))
            if idf <= 0.0:
                idf = 1e-6

            doc_length = float(match_info[L_O + j])  # |D|
            avg_length = float(match_info[A_O + j]) or 1.  # avgdl
            ratio = doc_length / avg_length

            num = term_frequency * (K + 1.0)
            b_part = 1.0 - B + (B * ratio)
            denom = term_frequency + (K * b_part)

            pc_score = idf * (num / denom)
            score += (pc_score * weight)

    return -score


def _json_contains(src_json, obj_json):
    stack = []
    try:
        stack.append((json.loads(obj_json), json.loads(src_json)))
    except:
        # Invalid JSON!
        return False

    while stack:
        obj, src = stack.pop()
        if isinstance(src, dict):
            if isinstance(obj, dict):
                for key in obj:
                    if key not in src:
                        return False
                    stack.append((obj[key], src[key]))
            elif isinstance(obj, list):
                for item in obj:
                    if item not in src:
                        return False
            elif obj not in src:
                return False
        elif isinstance(src, list):
            if isinstance(obj, dict):
                return False
            elif isinstance(obj, list):
                try:
                    for i in range(len(obj)):
                        stack.append((obj[i], src[i]))
                except IndexError:
                    return False
            elif obj not in src:
                return False
        elif obj != src:
            return False
    return True
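

# A usage sketch for the json_contains() user-defined function (not part of
# the original module; assumes a SqliteExtDatabase created with
# json_contains=True and a hypothetical "KV" model with a JSONField "value"):
#
#   db = SqliteExtDatabase(':memory:', json_contains=True)
#
#   # Match rows whose JSON document contains the given key/value pairs.
#   query = KV.select().where(
#       fn.json_contains(KV.value, json.dumps({'items': [1, 2]})))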