1import json
2import math
3import re
4import struct
5import sys
6
7from peewee import *
8from peewee import ColumnBase
9from peewee import EnclosedNodeList
10from peewee import Entity
11from peewee import Expression
12from peewee import Node
13from peewee import NodeList
14from peewee import OP
15from peewee import VirtualField
16from peewee import merge_dict
17from peewee import sqlite3
18try:
19    from playhouse._sqlite_ext import (
20        backup,
21        backup_to_file,
22        Blob,
23        ConnectionHelper,
24        register_bloomfilter,
25        register_hash_functions,
26        register_rank_functions,
27        sqlite_get_db_status,
28        sqlite_get_status,
29        TableFunction,
30        ZeroBlob,
31    )
32    CYTHON_SQLITE_EXTENSIONS = True
33except ImportError:
34    CYTHON_SQLITE_EXTENSIONS = False
35
36
# Python 2/3 compatibility shim: on Python 3 there is no `basestring`, so
# alias it to `str` for the isinstance() checks used below.
if sys.version_info[0] == 3:
    basestring = str


# matchinfo() format strings consumed by the FTS3/FTS4 ranking helpers.
FTS3_MATCHINFO = 'pcx'
FTS4_MATCHINFO = 'pcnalx'
# FTS4 requires SQLite >= 3.7.4; fall back to FTS3 on older libraries (or
# when the sqlite3 driver could not be imported by peewee at all).
if sqlite3 is not None:
    FTS_VERSION = 4 if sqlite3.sqlite_version_info[:3] >= (3, 7, 4) else 3
else:
    FTS_VERSION = 3

# FTS5 first shipped with SQLite 3.9.0.
FTS5_MIN_SQLITE_VERSION = (3, 9, 0)
49
50
class RowIDField(AutoField):
    """Auto-incrementing field mapped onto SQLite's implicit "rowid" column.

    The model attribute must be declared under the name "rowid"; binding it
    under any other name raises ValueError.
    """
    auto_increment = True
    column_name = name = required_name = 'rowid'

    def bind(self, model, name, *args):
        # Enforce the required attribute name before delegating upward.
        if name == self.required_name:
            super(RowIDField, self).bind(model, name, *args)
        else:
            raise ValueError('%s must be named "%s".' %
                             (type(self), self.required_name))
60
61
class DocIDField(RowIDField):
    """RowIDField variant for FTS3/FTS4 tables, which expose the implicit
    rowid under the alias "docid"; must be declared with the name "docid"."""
    column_name = name = required_name = 'docid'
64
65
class AutoIncrementField(AutoField):
    """AutoField whose column DDL carries SQLite's explicit AUTOINCREMENT
    keyword (guaranteeing monotonically-increasing rowids)."""

    def ddl(self, ctx):
        # Take the standard AutoField definition and append AUTOINCREMENT.
        base_ddl = super(AutoIncrementField, self).ddl(ctx)
        return NodeList((base_ddl, SQL('AUTOINCREMENT')))
70
71
class TDecimalField(DecimalField):
    # Store decimals as TEXT so SQLite preserves the exact value instead of
    # coercing it to a binary float.
    field_type = 'TEXT'
    # DecimalField would emit (max_digits, decimal_places) column modifiers,
    # which are meaningless for a TEXT column -- suppress them (return None).
    def get_modifiers(self): pass
75
76
class JSONPath(ColumnBase):
    """A JSON-path expression rooted at a JSONField.

    Instances are effectively immutable: subscripting with ``[]`` yields a
    new JSONPath with the extra path component appended.
    """
    def __init__(self, field, path=None):
        super(JSONPath, self).__init__()
        self._field = field
        self._path = path if path else ()

    @property
    def path(self):
        # Render as a SQLite json1 path string, e.g. "$.key[0]".
        return Value('$%s' % ''.join(self._path))

    def __getitem__(self, idx):
        # Integers address array elements; everything else is an object key.
        component = '[%s]' % idx if isinstance(idx, int) else '.%s' % idx
        return JSONPath(self._field, self._path + (component,))

    def set(self, value, as_json=None):
        """Set the value at this path (json_set)."""
        # Lists/dicts -- or an explicit request -- are serialized to JSON.
        if as_json or isinstance(value, (list, dict)):
            value = fn.json(self._field._json_dumps(value))
        return fn.json_set(self._field, self.path, value)

    def update(self, value):
        """Merge-patch `value` into the data at this path (json_patch)."""
        return self.set(fn.json_patch(self, self._field._json_dumps(value)))

    def remove(self):
        """Remove the value at this path (json_remove)."""
        return fn.json_remove(self._field, self.path)

    def json_type(self):
        return fn.json_type(self._field, self.path)

    def length(self):
        return fn.json_array_length(self._field, self.path)

    def children(self):
        return fn.json_each(self._field, self.path)

    def tree(self):
        return fn.json_tree(self._field, self.path)

    def __sql__(self, ctx):
        # With no path components, render the bare field itself.
        if self._path:
            return ctx.sql(fn.json_extract(self._field, self.path))
        return ctx.sql(self._field)
120
121
class JSONField(TextField):
    """TextField storing values as JSON text, with transparent
    (de)serialization and json1-based path access via JSONPath."""
    field_type = 'JSON'
    unpack = False

    def __init__(self, json_dumps=None, json_loads=None, **kwargs):
        # Serializer hooks default to the stdlib json module; callers may
        # substitute faster implementations (e.g. ujson).
        self._json_dumps = json_dumps if json_dumps is not None else json.dumps
        self._json_loads = json_loads if json_loads is not None else json.loads
        super(JSONField, self).__init__(**kwargs)

    def python_value(self, value):
        """Deserialize database text; fall back to the raw value when it is
        not valid JSON."""
        if value is None:
            return None
        try:
            return self._json_loads(value)
        except (TypeError, ValueError):
            return value

    def db_value(self, value):
        """Serialize plain Python values to JSON text; Node instances pass
        through untouched."""
        if value is None:
            return None
        if isinstance(value, Node):
            return value
        return fn.json(self._json_dumps(value))

    def _e(op):
        # Factory producing rich-comparison methods that JSON-encode bare
        # list/dict operands so they compare against the stored JSON text.
        def comparator(self, rhs):
            if isinstance(rhs, (list, dict)):
                rhs = Value(rhs, converter=self.db_value, unpack=False)
            return Expression(self, op, rhs)
        return comparator
    __eq__ = _e(OP.EQ)
    __ne__ = _e(OP.NE)
    __gt__ = _e(OP.GT)
    __ge__ = _e(OP.GTE)
    __lt__ = _e(OP.LT)
    __le__ = _e(OP.LTE)
    __hash__ = Field.__hash__

    def __getitem__(self, item):
        # field['key'] / field[idx] begins a JSONPath expression.
        return JSONPath(self)[item]

    def set(self, value, as_json=None):
        return JSONPath(self).set(value, as_json)

    def update(self, data):
        return JSONPath(self).update(data)

    def remove(self):
        return JSONPath(self).remove()

    def json_type(self):
        return fn.json_type(self)

    def length(self):
        return fn.json_array_length(self)

    def children(self):
        """
        Table-valued function over the document's direct children.

        Schema of `json_each` and `json_tree`:

        key,
        value,
        type TEXT (object, array, string, etc),
        atom (value for primitive/scalar types, NULL for array and object)
        id INTEGER (unique identifier for element)
        parent INTEGER (unique identifier of parent element or NULL)
        fullkey TEXT (full path describing element)
        path TEXT (path to the container of the current element)
        json JSON hidden (1st input parameter to function)
        root TEXT hidden (2nd input parameter, path at which to start)
        """
        return fn.json_each(self)

    def tree(self):
        """Table-valued function over the document, descending recursively.
        Same schema as children()."""
        return fn.json_tree(self)
195
196
class SearchField(Field):
    """Nullable field for the columns of full-text search virtual tables.

    Only `unindexed` and `column_name` may be customized; any other Field
    keyword argument is rejected, since FTS tables ignore types, defaults,
    constraints and indexes.
    """
    def __init__(self, unindexed=False, column_name=None, **kwargs):
        if kwargs:
            raise ValueError('SearchField does not accept these keyword '
                             'arguments: %s.' % sorted(kwargs))
        super(SearchField, self).__init__(unindexed=unindexed,
                                          column_name=column_name, null=True)

    def match(self, term):
        """Expression matching `term` against this column with MATCH."""
        return match(self, term)
207
208
class VirtualTableSchemaManager(SchemaManager):
    """SchemaManager that emits CREATE VIRTUAL TABLE statements for
    VirtualModel subclasses (FTS, closure tables, LSM1, etc)."""

    def _create_virtual_table(self, safe=True, **options):
        # Give the model a chance to validate/normalize its options first.
        options = self.model.clean_options(
            merge_dict(self.model._meta.options, options))

        # Structure:
        # CREATE VIRTUAL TABLE <model>
        # USING <extension_module>
        # ([prefix_arguments, ...] fields, ... [arguments, ...], [options...])
        ctx = self._create_context()
        ctx.literal('CREATE VIRTUAL TABLE ')
        if safe:
            ctx.literal('IF NOT EXISTS ')
        (ctx
         .sql(self.model)
         .literal(' USING '))

        ext_module = self.model._meta.extension_module
        # An extension module given as a Node (e.g. fn.fts5vocab(...)) is
        # rendered as-is; no argument list of our own is appended.
        if isinstance(ext_module, Node):
            return ctx.sql(ext_module)

        ctx.sql(SQL(ext_module)).literal(' ')
        arguments = []
        meta = self.model._meta

        if meta.prefix_arguments:
            arguments.extend([SQL(a) for a in meta.prefix_arguments])

        # Constraints, data-types, foreign and primary keys are all omitted.
        for field in meta.sorted_fields:
            # The implicit rowid and hidden fields never appear in the
            # virtual table's column list.
            if isinstance(field, (RowIDField)) or field._hidden:
                continue
            field_def = [Entity(field.column_name)]
            if field.unindexed:
                field_def.append(SQL('UNINDEXED'))
            arguments.append(NodeList(field_def))

        if meta.arguments:
            arguments.extend([SQL(a) for a in meta.arguments])

        if options:
            arguments.extend(self._create_table_option_sql(options))
        return ctx.sql(EnclosedNodeList(arguments))

    def _create_table(self, safe=True, **options):
        # Virtual models get CREATE VIRTUAL TABLE; everything else falls
        # through to the regular CREATE TABLE implementation.
        if issubclass(self.model, VirtualModel):
            return self._create_virtual_table(safe, **options)

        return super(VirtualTableSchemaManager, self)._create_table(
            safe, **options)
259
260
class VirtualModel(Model):
    """Base model class for SQLite virtual tables; schema creation is
    handled by VirtualTableSchemaManager."""

    class Meta:
        arguments = None          # Extra args inside the USING (...) list.
        extension_module = None   # Name (or Node) of the virtual-table module.
        prefix_arguments = None   # Args emitted before the column list.
        primary_key = False
        schema_manager_class = VirtualTableSchemaManager

    @classmethod
    def clean_options(cls, options):
        # Hook for subclasses to validate/normalize table options; the
        # default passes them through unchanged.
        return options
272
273
class BaseFTSModel(VirtualModel):
    """Shared table-option post-processing for FTS3/FTS4/FTS5 models."""

    @classmethod
    def clean_options(cls, options):
        """Normalize the `content`, `prefix` and `tokenize` options into the
        literal forms the FTS extension expects."""
        content = options.get('content')
        prefix = options.get('prefix')
        tokenize = options.get('tokenize')

        if isinstance(content, basestring) and content == '':
            # Special-case content-less full-text search tables.
            options['content'] = "''"
        elif isinstance(content, Field):
            # Special-case to ensure fields are fully-qualified.
            options['content'] = Entity(content.model._meta.table_name,
                                        content.column_name)

        if prefix:
            # Accept a list/tuple of prefix lengths or a pre-built string,
            # always emitting a single-quoted value.
            if isinstance(prefix, (list, tuple)):
                prefix = ','.join(map(str, prefix))
            options['prefix'] = "'%s'" % prefix.strip("' ")

        if tokenize and cls._meta.extension_module.lower() == 'fts5':
            # Tokenizers need to be in quoted string for FTS5, but not for FTS3
            # or FTS4.
            options['tokenize'] = '"%s"' % tokenize

        return options
300
301
class FTSModel(BaseFTSModel):
    """
    VirtualModel class for creating tables that use either the FTS3 or FTS4
    search extensions. Peewee automatically determines which version of the
    FTS extension is supported and will use FTS4 if possible.
    """
    # FTS3/4 uses "docid" in the same way a normal table uses "rowid".
    docid = DocIDField()

    class Meta:
        extension_module = 'FTS%s' % FTS_VERSION

    @classmethod
    def _fts_cmd(cls, cmd):
        # FTS special commands are issued by inserting the command string
        # into the column that shares the table's name, e.g.
        #   INSERT INTO tbl(tbl) VALUES('optimize');
        # NOTE(review): `cmd` is interpolated directly into the SQL text, so
        # only the fixed command strings used by the methods below should
        # ever be passed here.
        tbl = cls._meta.table_name
        res = cls._meta.database.execute_sql(
            "INSERT INTO %s(%s) VALUES('%s');" % (tbl, tbl, cmd))
        return res.fetchone()

    @classmethod
    def optimize(cls):
        """Merge the FTS index b-trees into a single segment."""
        return cls._fts_cmd('optimize')

    @classmethod
    def rebuild(cls):
        """Rebuild the full-text index from the content table."""
        return cls._fts_cmd('rebuild')

    @classmethod
    def integrity_check(cls):
        """Verify that the index is consistent with the content table."""
        return cls._fts_cmd('integrity-check')

    @classmethod
    def merge(cls, blocks=200, segments=8):
        """Incrementally merge `blocks` blocks from `segments` segments."""
        return cls._fts_cmd('merge=%s,%s' % (blocks, segments))

    @classmethod
    def automerge(cls, state=True):
        """Enable or disable automatic incremental merging."""
        return cls._fts_cmd('automerge=%s' % (state and '1' or '0'))

    @classmethod
    def match(cls, term):
        """
        Generate a `MATCH` expression appropriate for searching this table.
        """
        return match(cls._meta.entity, term)

    @classmethod
    def rank(cls, *weights):
        # The simple rank function only needs the compact 'pcx' matchinfo
        # format (FTS3-compatible).
        matchinfo = fn.matchinfo(cls._meta.entity, FTS3_MATCHINFO)
        return fn.fts_rank(matchinfo, *weights)

    @classmethod
    def bm25(cls, *weights):
        # BM25 requires the richer 'pcnalx' matchinfo format (FTS4 only).
        match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO)
        return fn.fts_bm25(match_info, *weights)

    @classmethod
    def bm25f(cls, *weights):
        match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO)
        return fn.fts_bm25f(match_info, *weights)

    @classmethod
    def lucene(cls, *weights):
        match_info = fn.matchinfo(cls._meta.entity, FTS4_MATCHINFO)
        return fn.fts_lucene(match_info, *weights)

    @classmethod
    def _search(cls, term, weights, with_score, score_alias, score_fn,
                explicit_ordering):
        # Shared implementation behind search/search_bm25/etc. `weights` may
        # be empty, a positional sequence of per-column weights, or a dict
        # keyed by field instance or field name.
        if not weights:
            rank = score_fn()
        elif isinstance(weights, dict):
            weight_args = []
            for field in cls._meta.sorted_fields:
                # Attempt to get the specified weight of the field by looking
                # it up using its field instance followed by name.
                field_weight = weights.get(field, weights.get(field.name, 1.0))
                weight_args.append(field_weight)
            rank = score_fn(*weight_args)
        else:
            rank = score_fn(*weights)

        selection = ()
        order_by = rank
        if with_score:
            selection = (cls, rank.alias(score_alias))
        if with_score and not explicit_ordering:
            # Reference the aliased score so it is not computed twice.
            order_by = SQL(score_alias)

        return (cls
                .select(*selection)
                .where(cls.match(term))
                .order_by(order_by))

    @classmethod
    def search(cls, term, weights=None, with_score=False, score_alias='score',
               explicit_ordering=False):
        """Full-text search using selected `term`."""
        return cls._search(
            term,
            weights,
            with_score,
            score_alias,
            cls.rank,
            explicit_ordering)

    @classmethod
    def search_bm25(cls, term, weights=None, with_score=False,
                    score_alias='score', explicit_ordering=False):
        """Full-text search for selected `term` using BM25 algorithm."""
        return cls._search(
            term,
            weights,
            with_score,
            score_alias,
            cls.bm25,
            explicit_ordering)

    @classmethod
    def search_bm25f(cls, term, weights=None, with_score=False,
                     score_alias='score', explicit_ordering=False):
        """Full-text search for selected `term` using the BM25f variant."""
        return cls._search(
            term,
            weights,
            with_score,
            score_alias,
            cls.bm25f,
            explicit_ordering)

    @classmethod
    def search_lucene(cls, term, weights=None, with_score=False,
                      score_alias='score', explicit_ordering=False):
        """Full-text search for selected `term` using Lucene-style scoring."""
        return cls._search(
            term,
            weights,
            with_score,
            score_alias,
            cls.lucene,
            explicit_ordering)
443
444
# Character sets used to sanitize FTS5 query strings: any ASCII character
# not present in _alphanum is treated as invalid inside a bare (unquoted)
# search token. chr(26) (SUB) is the default replacement character.
_alphabet = 'abcdefghijklmnopqrstuvwxyz'
_alphanum = (set('\t ,"(){}*:_+0123456789') |
             set(_alphabet) |
             set(_alphabet.upper()) |
             set((chr(26),)))
_invalid_ascii = set(chr(p) for p in range(128) if chr(p) not in _alphanum)
# Splits a query into tokens, keeping double-quoted phrases (including
# backslash escapes) together as single tokens.
_quote_re = re.compile(r'(?:[^\s"]|"(?:\\.|[^"])*")+')
452
453
class FTS5Model(BaseFTSModel):
    """
    Requires SQLite >= 3.9.0.

    Table options:

    content: table name of external content, or empty string for "contentless"
    content_rowid: column name of external content primary key
    prefix: integer(s). Ex: '2' or '2 3 4'
    tokenize: porter, unicode61, ascii. Ex: 'porter unicode61'

    The unicode tokenizer supports the following parameters:

    * remove_diacritics (1 or 0, default is 1)
    * tokenchars (string of characters, e.g. '-_')
    * separators (string of characters)

    Parameters are passed as alternating parameter name and value, so:

    {'tokenize': "unicode61 remove_diacritics 0 tokenchars '-_'"}

    Content-less tables:

    If you don't need the full-text content in its original form, you can
    specify a content-less table. Searches and auxiliary functions will work
    as usual, but the only values returned when SELECT-ing can be rowid. Also
    content-less tables do not support UPDATE or DELETE.

    External content tables:

    You can set up triggers to sync these, e.g.

    -- Create a table. And an external content fts5 table to index it.
    CREATE TABLE tbl(a INTEGER PRIMARY KEY, b);
    CREATE VIRTUAL TABLE ft USING fts5(b, content='tbl', content_rowid='a');

    -- Triggers to keep the FTS index up to date.
    CREATE TRIGGER tbl_ai AFTER INSERT ON tbl BEGIN
      INSERT INTO ft(rowid, b) VALUES (new.a, new.b);
    END;
    CREATE TRIGGER tbl_ad AFTER DELETE ON tbl BEGIN
      INSERT INTO ft(fts_idx, rowid, b) VALUES('delete', old.a, old.b);
    END;
    CREATE TRIGGER tbl_au AFTER UPDATE ON tbl BEGIN
      INSERT INTO ft(fts_idx, rowid, b) VALUES('delete', old.a, old.b);
      INSERT INTO ft(rowid, b) VALUES (new.a, new.b);
    END;

    Built-in auxiliary functions:

    * bm25(tbl[, weight_0, ... weight_n])
    * highlight(tbl, col_idx, prefix, suffix)
    * snippet(tbl, col_idx, prefix, suffix, ?, max_tokens)
    """
    # FTS5 does not support declared primary keys, but we can use the
    # implicit rowid.
    rowid = RowIDField()

    class Meta:
        extension_module = 'fts5'

    _error_messages = {
        'field_type': ('Besides the implicit `rowid` column, all columns must '
                       'be instances of SearchField'),
        'index': 'Secondary indexes are not supported for FTS5 models',
        'pk': 'FTS5 models must use the default `rowid` primary key',
    }

    @classmethod
    def validate_model(cls):
        # Perform FTS5-specific validation and options post-processing.
        if cls._meta.primary_key.name != 'rowid':
            raise ImproperlyConfigured(cls._error_messages['pk'])
        for field in cls._meta.fields.values():
            if not isinstance(field, (SearchField, RowIDField)):
                raise ImproperlyConfigured(cls._error_messages['field_type'])
        if cls._meta.indexes:
            raise ImproperlyConfigured(cls._error_messages['index'])

    @classmethod
    def fts5_installed(cls):
        """Return True if FTS5 is usable, loading it as a run-time extension
        into the model's database when it is not compiled in."""
        if sqlite3.sqlite_version_info[:3] < FTS5_MIN_SQLITE_VERSION:
            return False

        # Test in-memory DB to determine if the FTS5 extension is installed.
        tmp_db = sqlite3.connect(':memory:')
        try:
            tmp_db.execute('CREATE VIRTUAL TABLE fts5test USING fts5 (data);')
        except Exception:
            # Not compiled in -- attempt to load FTS5 as an extension.
            # (Previously a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.)
            try:
                tmp_db.enable_load_extension(True)
                tmp_db.load_extension('fts5')
            except Exception:
                return False
            else:
                cls._meta.database.load_extension('fts5')
        finally:
            tmp_db.close()

        return True

    @staticmethod
    def validate_query(query):
        """
        Simple helper function to indicate whether a search query is a
        valid FTS5 query. Note: this simply looks at the characters being
        used, and is not guaranteed to catch all problematic queries.
        """
        tokens = _quote_re.findall(query)
        for token in tokens:
            # Double-quoted phrases may contain any characters.
            if token.startswith('"') and token.endswith('"'):
                continue
            if set(token) & _invalid_ascii:
                return False
        return True

    @staticmethod
    def clean_query(query, replace=chr(26)):
        """
        Clean a query of invalid tokens, substituting `replace` for any
        invalid characters found outside of quoted phrases.
        """
        accum = []
        any_invalid = False
        tokens = _quote_re.findall(query)
        for token in tokens:
            if token.startswith('"') and token.endswith('"'):
                accum.append(token)
                continue
            token_set = set(token)
            invalid_for_token = token_set & _invalid_ascii
            if invalid_for_token:
                any_invalid = True
                for c in invalid_for_token:
                    token = token.replace(c, replace)
            accum.append(token)

        # Only re-join when something changed, preserving the original
        # whitespace otherwise.
        if any_invalid:
            return ' '.join(accum)
        return query

    @classmethod
    def match(cls, term):
        """
        Generate a `MATCH` expression appropriate for searching this table.
        """
        return match(cls._meta.entity, term)

    @classmethod
    def rank(cls, *args):
        # FTS5 exposes a built-in "rank" column; with explicit weights we
        # call bm25() directly instead.
        return cls.bm25(*args) if args else SQL('rank')

    @classmethod
    def bm25(cls, *weights):
        """Score expression using FTS5's built-in bm25() function."""
        return fn.bm25(cls._meta.entity, *weights)

    @classmethod
    def search(cls, term, weights=None, with_score=False, score_alias='score',
               explicit_ordering=False):
        """Full-text search using selected `term`."""
        return cls.search_bm25(
            FTS5Model.clean_query(term),
            weights,
            with_score,
            score_alias,
            explicit_ordering)

    @classmethod
    def search_bm25(cls, term, weights=None, with_score=False,
                    score_alias='score', explicit_ordering=False):
        """Full-text search using selected `term`, ranked by bm25()."""
        if not weights:
            rank = SQL('rank')
        elif isinstance(weights, dict):
            # Map per-field weights (keyed by field instance or name) onto
            # the positional weight arguments bm25() expects; unindexed
            # columns carry no weight.
            weight_args = []
            for field in cls._meta.sorted_fields:
                if isinstance(field, SearchField) and not field.unindexed:
                    weight_args.append(
                        weights.get(field, weights.get(field.name, 1.0)))
            rank = fn.bm25(cls._meta.entity, *weight_args)
        else:
            rank = fn.bm25(cls._meta.entity, *weights)

        selection = ()
        order_by = rank
        if with_score:
            selection = (cls, rank.alias(score_alias))
        if with_score and not explicit_ordering:
            order_by = SQL(score_alias)

        return (cls
                .select(*selection)
                .where(cls.match(FTS5Model.clean_query(term)))
                .order_by(order_by))

    @classmethod
    def _fts_cmd_sql(cls, cmd, **extra_params):
        # FTS5 commands are INSERTs into the column named after the table,
        # e.g. INSERT INTO tbl(tbl, rank) VALUES('automerge', 4);
        tbl = cls._meta.entity
        columns = [tbl]
        values = [cmd]
        for key, value in extra_params.items():
            columns.append(Entity(key))
            values.append(value)

        return NodeList((
            SQL('INSERT INTO'),
            cls._meta.entity,
            EnclosedNodeList(columns),
            SQL('VALUES'),
            EnclosedNodeList(values)))

    @classmethod
    def _fts_cmd(cls, cmd, **extra_params):
        query = cls._fts_cmd_sql(cmd, **extra_params)
        return cls._meta.database.execute(query)

    @classmethod
    def automerge(cls, level):
        """Set the automerge level; 0 disables automatic merging, 16 is
        the maximum."""
        if not (0 <= level <= 16):
            raise ValueError('level must be between 0 and 16')
        return cls._fts_cmd('automerge', rank=level)

    @classmethod
    def merge(cls, npages):
        """Incrementally merge up to `npages` pages of index b-trees."""
        return cls._fts_cmd('merge', rank=npages)

    @classmethod
    def set_pgsz(cls, pgsz):
        """Set the page size used by the FTS5 index."""
        return cls._fts_cmd('pgsz', rank=pgsz)

    @classmethod
    def set_rank(cls, rank_expression):
        """Set the ranking function backing the "rank" column."""
        return cls._fts_cmd('rank', rank=rank_expression)

    @classmethod
    def delete_all(cls):
        """Delete all rows from the full-text index."""
        return cls._fts_cmd('delete-all')

    @classmethod
    def VocabModel(cls, table_type='row', table=None):
        """Create (and memoize per table_type) a model for the fts5vocab
        virtual table exposing this index's vocabulary."""
        if table_type not in ('row', 'col', 'instance'):
            raise ValueError('table_type must be either "row", "col" or '
                             '"instance".')

        attr = '_vocab_model_%s' % table_type

        if not hasattr(cls, attr):
            class Meta:
                database = cls._meta.database
                table_name = table or cls._meta.table_name + '_v'
                extension_module = fn.fts5vocab(
                    cls._meta.entity,
                    SQL(table_type))

            attrs = {
                'term': VirtualField(TextField),
                'doc': IntegerField(),
                'cnt': IntegerField(),
                'rowid': RowIDField(),
                'Meta': Meta,
            }
            if table_type == 'col':
                attrs['col'] = VirtualField(TextField)
            elif table_type == 'instance':
                attrs['offset'] = VirtualField(IntegerField)

            class_name = '%sVocab' % cls.__name__
            setattr(cls, attr, type(class_name, (VirtualModel,), attrs))

        return getattr(cls, attr)
723
724
def ClosureTable(model_class, foreign_key=None, referencing_class=None,
                 referencing_key=None):
    """Model factory for the transitive closure extension.

    The closure table mirrors the hierarchy described by model_class's
    self-referential foreign key; referencing_class/referencing_key allow
    the id/parent columns to live on a different table.
    """
    if referencing_class is None:
        referencing_class = model_class

    # Auto-discover the self-referential foreign key when not given.
    if foreign_key is None:
        for field_obj in model_class._meta.refs:
            if field_obj.rel_model is model_class:
                foreign_key = field_obj
                break
        else:
            raise ValueError('Unable to find self-referential foreign key.')

    source_key = model_class._meta.primary_key
    if referencing_key is None:
        referencing_key = source_key

    class BaseClosureTable(VirtualModel):
        depth = VirtualField(IntegerField)
        id = VirtualField(IntegerField)
        idcolumn = VirtualField(TextField)
        parentcolumn = VirtualField(TextField)
        root = VirtualField(IntegerField)
        tablename = VirtualField(TextField)

        class Meta:
            extension_module = 'transitive_closure'

        @classmethod
        def descendants(cls, node, depth=None, include_node=False):
            """Query the subtree rooted at `node`, optionally restricted to
            rows at exactly `depth` levels below it."""
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(source_key == cls.id))
                     .where(cls.root == node)
                     .objects())
            if depth is not None:
                query = query.where(cls.depth == depth)
            elif not include_node:
                # Depth 0 is the node itself; exclude it by default.
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def ancestors(cls, node, depth=None, include_node=False):
            """Query the chain of parents above `node`."""
            query = (model_class
                     .select(model_class, cls.depth.alias('depth'))
                     .join(cls, on=(source_key == cls.root))
                     .where(cls.id == node)
                     .objects())
            if depth:
                query = query.where(cls.depth == depth)
            elif not include_node:
                query = query.where(cls.depth > 0)
            return query

        @classmethod
        def siblings(cls, node, include_node=False):
            """Query nodes sharing `node`'s parent."""
            if referencing_class is model_class:
                # self-join
                fk_value = node.__data__.get(foreign_key.name)
                query = model_class.select().where(foreign_key == fk_value)
            else:
                # siblings as given in reference_class
                siblings = (referencing_class
                            .select(referencing_key)
                            .join(cls, on=(foreign_key == cls.root))
                            .where((cls.id == node) & (cls.depth == 1)))

                # the according models
                query = (model_class
                         .select()
                         .where(source_key << siblings)
                         .objects())

            if not include_node:
                query = query.where(source_key != node)

            return query

    class Meta:
        # Closure-table options identify the source table and its id and
        # parent columns.
        database = referencing_class._meta.database
        options = {
            'tablename': referencing_class._meta.table_name,
            'idcolumn': referencing_key.column_name,
            'parentcolumn': foreign_key.column_name}
        primary_key = False

    name = '%sClosure' % model_class.__name__
    return type(name, (BaseClosureTable,), {'Meta': Meta})
814
815
class LSMTable(VirtualModel):
    """VirtualModel for the lsm1 extension, exposing an LSM key/value
    database file as a SQLite virtual table with dict-like access."""

    class Meta:
        extension_module = 'lsm1'
        filename = None  # Path of the underlying LSM database file.

    @classmethod
    def clean_options(cls, options):
        """Validate the model declaration and build the lsm1 prefix
        arguments: (filename, key column, key data type)."""
        filename = cls._meta.filename
        if not filename:
            raise ValueError('LSM1 extension requires that you specify a '
                             'filename for the LSM database.')
        else:
            if len(filename) >= 2 and filename[0] != '"':
                filename = '"%s"' % filename
        if not cls._meta.primary_key:
            raise ValueError('LSM1 models must specify a primary-key field.')

        key = cls._meta.primary_key
        if isinstance(key, AutoField):
            raise ValueError('LSM1 models must explicitly declare a primary '
                             'key field.')
        if not isinstance(key, (TextField, BlobField, IntegerField)):
            raise ValueError('LSM1 key must be a TextField, BlobField, or '
                             'IntegerField.')
        # The key is passed via the module arguments, not the column list.
        key._hidden = True
        if isinstance(key, IntegerField):
            data_type = 'UINT'
        elif isinstance(key, BlobField):
            data_type = 'BLOB'
        else:
            data_type = 'TEXT'
        cls._meta.prefix_arguments = [filename, '"%s"' % key.name, data_type]

        # Does the key map to a scalar value, or a tuple of values?
        if len(cls._meta.sorted_fields) == 2:
            cls._meta._value_field = cls._meta.sorted_fields[1]
        else:
            cls._meta._value_field = None

        return options

    @classmethod
    def load_extension(cls, path='lsm.so'):
        """Load the lsm1 loadable extension into the database."""
        cls._meta.database.load_extension(path)

    @staticmethod
    def slice_to_expr(key, idx):
        # Translate a Python slice into a range expression on `key`;
        # returns None for the unbounded slice [:] (no filter).
        if idx.start is not None and idx.stop is not None:
            return key.between(idx.start, idx.stop)
        elif idx.start is not None:
            return key >= idx.start
        elif idx.stop is not None:
            return key <= idx.stop

    @staticmethod
    def _apply_lookup_to_query(query, key, lookup):
        # Returns (query, is_single_row_lookup).
        if isinstance(lookup, slice):
            expr = LSMTable.slice_to_expr(key, lookup)
            if expr is not None:
                query = query.where(expr)
            return query, False
        elif isinstance(lookup, Expression):
            return query.where(lookup), False
        else:
            return query.where(key == lookup), True

    @classmethod
    def get_by_id(cls, pk):
        """Dict-like read: a scalar key returns the stored value (or row);
        slices and expressions return a query."""
        query, is_single = cls._apply_lookup_to_query(
            cls.select().namedtuples(),
            cls._meta.primary_key,
            pk)

        if is_single:
            try:
                row = query.get()
            except cls.DoesNotExist:
                # Surface key misses as KeyError, mirroring dict semantics.
                raise KeyError(pk)
            return row[1] if cls._meta._value_field is not None else row
        else:
            return query

    @classmethod
    def set_by_id(cls, key, value):
        """Dict-like assignment, implemented as an upsert (REPLACE).

        `value` may be a scalar (single-value tables), a tuple of values
        for the non-key fields, a dict, or a model instance.
        """
        if cls._meta._value_field is not None:
            data = {cls._meta._value_field: value}
        elif isinstance(value, tuple):
            # Positional values map onto the non-key fields, in order.
            data = {}
            for field, fval in zip(cls._meta.sorted_fields[1:], value):
                data[field] = fval
        elif isinstance(value, dict):
            data = value
        elif isinstance(value, cls):
            # Model instances keep their field values in __data__ (see the
            # usage in ClosureTable.siblings), not in the instance __dict__.
            # Copy so the caller's instance is not mutated below.
            data = dict(value.__data__)
        else:
            # Previously `data` was left unbound here, producing a
            # confusing NameError on the following line.
            raise ValueError('Unsupported value type: %r.' % type(value))
        data[cls._meta.primary_key] = key
        cls.replace(data).execute()

    @classmethod
    def delete_by_id(cls, pk):
        """Dict-like deletion; accepts scalar keys, slices or expressions."""
        query, _ = cls._apply_lookup_to_query(
            cls.delete(),
            cls._meta.primary_key,
            pk)
        return query.execute()
920
921
# Register the MATCH operator so expressions render as "lhs MATCH rhs",
# which is how SQLite full-text (FTS) tables are queried.
OP.MATCH = 'MATCH'
923
924def _sqlite_regexp(regex, value):
925    return re.search(regex, value) is not None
926
927
class SqliteExtDatabase(SqliteDatabase):
    """
    SqliteDatabase subclass that can register optional user-defined
    functions on every connection: FTS ranking (rank/bm25), hashing,
    REGEXP, bloom-filter and json_contains helpers. Uses the compiled C
    implementations from playhouse._sqlite_ext when available, falling
    back to the pure-Python versions in this module.
    """
    def __init__(self, database, c_extensions=None, rank_functions=True,
                 hash_functions=False, regexp_function=False,
                 bloomfilter=False, json_contains=False, *args, **kwargs):
        super(SqliteExtDatabase, self).__init__(database, *args, **kwargs)
        self._row_factory = None

        # c_extensions=True requires the compiled helper; None means "use it
        # if importable"; False disables it even when present.
        if c_extensions and not CYTHON_SQLITE_EXTENSIONS:
            raise ImproperlyConfigured('SqliteExtDatabase initialized with '
                                       'C extensions, but shared library was '
                                       'not found!')
        prefer_c = CYTHON_SQLITE_EXTENSIONS and (c_extensions is not False)
        if rank_functions:
            if prefer_c:
                register_rank_functions(self)
            else:
                # Pure-Python ranking; bm25f/lucene are not implemented in
                # Python, so bm25 is registered under those names as well.
                self.register_function(bm25, 'fts_bm25')
                self.register_function(rank, 'fts_rank')
                self.register_function(bm25, 'fts_bm25f')  # Fall back to bm25.
                self.register_function(bm25, 'fts_lucene')
        if hash_functions:
            if not prefer_c:
                raise ValueError('C extension required to register hash '
                                 'functions.')
            register_hash_functions(self)
        if regexp_function:
            self.register_function(_sqlite_regexp, 'regexp', 2)
        if bloomfilter:
            if not prefer_c:
                raise ValueError('C extension required to use bloomfilter.')
            register_bloomfilter(self)
        if json_contains:
            self.register_function(_json_contains, 'json_contains')

        self._c_extensions = prefer_c

    def _add_conn_hooks(self, conn):
        # Apply the user-supplied row factory to each new connection.
        super(SqliteExtDatabase, self)._add_conn_hooks(conn)
        if self._row_factory:
            conn.row_factory = self._row_factory

    def row_factory(self, fn):
        """Decorator: set the sqlite3 row factory used by new connections."""
        self._row_factory = fn
971
972
if CYTHON_SQLITE_EXTENSIONS:
    # Flag values accepted by sqlite3_status() -- process-wide counters.
    # These mirror the SQLITE_STATUS_* constants in the SQLite C API.
    SQLITE_STATUS_MEMORY_USED = 0
    SQLITE_STATUS_PAGECACHE_USED = 1
    SQLITE_STATUS_PAGECACHE_OVERFLOW = 2
    SQLITE_STATUS_SCRATCH_USED = 3
    SQLITE_STATUS_SCRATCH_OVERFLOW = 4
    SQLITE_STATUS_MALLOC_SIZE = 5
    SQLITE_STATUS_PARSER_STACK = 6
    SQLITE_STATUS_PAGECACHE_SIZE = 7
    SQLITE_STATUS_SCRATCH_SIZE = 8
    SQLITE_STATUS_MALLOC_COUNT = 9
    # Flag values accepted by sqlite3_db_status() -- per-connection counters,
    # mirroring the SQLITE_DBSTATUS_* constants in the SQLite C API.
    SQLITE_DBSTATUS_LOOKASIDE_USED = 0
    SQLITE_DBSTATUS_CACHE_USED = 1
    SQLITE_DBSTATUS_SCHEMA_USED = 2
    SQLITE_DBSTATUS_STMT_USED = 3
    SQLITE_DBSTATUS_LOOKASIDE_HIT = 4
    SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE = 5
    SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL = 6
    SQLITE_DBSTATUS_CACHE_HIT = 7
    SQLITE_DBSTATUS_CACHE_MISS = 8
    SQLITE_DBSTATUS_CACHE_WRITE = 9
    SQLITE_DBSTATUS_DEFERRED_FKS = 10
    #SQLITE_DBSTATUS_CACHE_USED_SHARED = 11
996
    def __status__(flag, return_highwater=False):
        """
        Expose a sqlite3_status() call for a particular flag as a property of
        the Database object.

        :param flag: one of the SQLITE_STATUS_* constants defined above.
        :param return_highwater: if true, return only the highwater mark;
            otherwise return the full result from sqlite_get_status().
        """
        def getter(self):
            result = sqlite_get_status(flag)
            return result[1] if return_highwater else result
        return property(getter)
1006
    def __dbstatus__(flag, return_highwater=False, return_current=False):
        """
        Expose a sqlite3_dbstatus() call for a particular flag as a property of
        the Database instance. Unlike sqlite3_status(), the dbstatus properties
        pertain to the current connection.

        :param flag: one of the SQLITE_DBSTATUS_* constants defined above.
        :param return_highwater: return only the highwater mark.
        :param return_current: return only the current value; takes precedence
            over return_highwater.
        :raises ImproperlyConfigured: if no connection is currently open.
        """
        def getter(self):
            # dbstatus requires a live connection handle.
            if self._state.conn is None:
                raise ImproperlyConfigured('database connection not opened.')
            result = sqlite_get_db_status(self._state.conn, flag)
            if return_current:
                return result[0]
            return result[1] if return_highwater else result
        return property(getter)
1021
    class CSqliteExtDatabase(SqliteExtDatabase):
        """
        SqliteExtDatabase backed by the compiled C helper module. Adds
        commit/rollback/update hooks, online backup, incremental blob I/O
        and sqlite3_status()/sqlite3_dbstatus() properties.
        """
        def __init__(self, *args, **kwargs):
            # ConnectionHelper wrapping the active sqlite3 connection; set up
            # in _add_conn_hooks() each time a connection is opened.
            self._conn_helper = None
            self._commit_hook = self._rollback_hook = self._update_hook = None
            self._replace_busy_handler = False
            super(CSqliteExtDatabase, self).__init__(*args, **kwargs)

        def init(self, database, replace_busy_handler=False, **kwargs):
            """
            :param replace_busy_handler: replace the default busy handler with
                the C implementation provided by the extension module.
            """
            super(CSqliteExtDatabase, self).init(database, **kwargs)
            self._replace_busy_handler = replace_busy_handler

        def _close(self, conn):
            # Uninstall any hooks before the connection is torn down.
            if self._commit_hook:
                self._conn_helper.set_commit_hook(None)
            if self._rollback_hook:
                self._conn_helper.set_rollback_hook(None)
            if self._update_hook:
                self._conn_helper.set_update_hook(None)
            return super(CSqliteExtDatabase, self)._close(conn)

        def _add_conn_hooks(self, conn):
            # (Re-)install user hooks and busy handler on each new connection.
            super(CSqliteExtDatabase, self)._add_conn_hooks(conn)
            self._conn_helper = ConnectionHelper(conn)
            if self._commit_hook is not None:
                self._conn_helper.set_commit_hook(self._commit_hook)
            if self._rollback_hook is not None:
                self._conn_helper.set_rollback_hook(self._rollback_hook)
            if self._update_hook is not None:
                self._conn_helper.set_update_hook(self._update_hook)
            if self._replace_busy_handler:
                # set_busy_handler expects milliseconds; _timeout is seconds.
                timeout = self._timeout or 5
                self._conn_helper.set_busy_handler(timeout * 1000)

        def on_commit(self, fn):
            """Decorator: register fn as the commit hook (applies immediately
            if a connection is open, otherwise on next connect)."""
            self._commit_hook = fn
            if not self.is_closed():
                self._conn_helper.set_commit_hook(fn)
            return fn

        def on_rollback(self, fn):
            """Decorator: register fn as the rollback hook."""
            self._rollback_hook = fn
            if not self.is_closed():
                self._conn_helper.set_rollback_hook(fn)
            return fn

        def on_update(self, fn):
            """Decorator: register fn as the update hook."""
            self._update_hook = fn
            if not self.is_closed():
                self._conn_helper.set_update_hook(fn)
            return fn

        def changes(self):
            """Return the number of rows changed by the last statement."""
            return self._conn_helper.changes()

        @property
        def last_insert_rowid(self):
            # Rowid of the most recent successful INSERT on this connection.
            return self._conn_helper.last_insert_rowid()

        @property
        def autocommit(self):
            # Whether the connection is currently in autocommit mode.
            return self._conn_helper.autocommit()

        def backup(self, destination, pages=None, name=None, progress=None):
            """Perform an online backup into another open database."""
            return backup(self.connection(), destination.connection(),
                          pages=pages, name=name, progress=progress)

        def backup_to_file(self, filename, pages=None, name=None,
                           progress=None):
            """Perform an online backup into the given filename."""
            return backup_to_file(self.connection(), filename, pages=pages,
                                  name=name, progress=progress)

        def blob_open(self, table, column, rowid, read_only=False):
            """Open a Blob for incremental I/O against the given cell."""
            return Blob(self, table, column, rowid, read_only)

        # Status properties.
        memory_used = __status__(SQLITE_STATUS_MEMORY_USED)
        malloc_size = __status__(SQLITE_STATUS_MALLOC_SIZE, True)
        malloc_count = __status__(SQLITE_STATUS_MALLOC_COUNT)
        pagecache_used = __status__(SQLITE_STATUS_PAGECACHE_USED)
        pagecache_overflow = __status__(SQLITE_STATUS_PAGECACHE_OVERFLOW)
        pagecache_size = __status__(SQLITE_STATUS_PAGECACHE_SIZE, True)
        scratch_used = __status__(SQLITE_STATUS_SCRATCH_USED)
        scratch_overflow = __status__(SQLITE_STATUS_SCRATCH_OVERFLOW)
        scratch_size = __status__(SQLITE_STATUS_SCRATCH_SIZE, True)

        # Connection status properties.
        lookaside_used = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_USED)
        lookaside_hit = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_HIT, True)
        lookaside_miss = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE,
                                      True)
        lookaside_miss_full = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL,
                                           True)
        cache_used = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED, False, True)
        #cache_used_shared = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED_SHARED,
        #                                 False, True)
        schema_used = __dbstatus__(SQLITE_DBSTATUS_SCHEMA_USED, False, True)
        statement_used = __dbstatus__(SQLITE_DBSTATUS_STMT_USED, False, True)
        cache_hit = __dbstatus__(SQLITE_DBSTATUS_CACHE_HIT, False, True)
        cache_miss = __dbstatus__(SQLITE_DBSTATUS_CACHE_MISS, False, True)
        cache_write = __dbstatus__(SQLITE_DBSTATUS_CACHE_WRITE, False, True)
1122
1123
def match(lhs, rhs):
    """Generate an SQLite ``lhs MATCH rhs`` expression (full-text search)."""
    return Expression(lhs, OP.MATCH, rhs)
1126
1127def _parse_match_info(buf):
1128    # See http://sqlite.org/fts3.html#matchinfo
1129    bufsize = len(buf)  # Length in bytes.
1130    return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)]
1131
def get_weights(ncol, raw_weights):
    """
    Build a per-column weight list of length ncol. With no raw weights every
    column is weighted 1; otherwise the given weights are used in order and
    any remaining columns get weight 0.
    """
    if not raw_weights:
        return [1] * ncol
    weights = [0] * ncol
    position = 0
    for weight in raw_weights:
        weights[position] = weight
        position += 1
    return weights
1140
# Ranking implementation that parses the FTS 'pcx' matchinfo layout.
def rank(raw_match_info, *raw_weights):
    """
    Simple frequency-ratio ranking for FTS queries, modeled on the example
    rank function at http://sqlite.org/fts3.html#appendix_a. Optional
    positional weights apply per column. Returns a negated score so that
    ascending ORDER BY yields best matches first.
    """
    match_info = _parse_match_info(raw_match_info)
    nphrase, ncol = match_info[0], match_info[1]
    weights = get_weights(ncol, raw_weights)
    score = 0.0

    # The 'x' section stores, for each (phrase, column) pair, a triple of
    # values laid out phrase-major beginning at offset 2. E.g., for two
    # phrases and three columns:
    # p0 : c0=[0, 1, 2],   c1=[3, 4, 5],    c2=[6, 7, 8]
    # p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17]
    for phrase in range(nphrase):
        base = 2 + (phrase * ncol * 3)
        for column in range(ncol):
            weight = weights[column]
            if not weight:
                continue

            offset = base + (column * 3)

            # Compare how often the phrase occurs in this row's column with
            # its frequency across all rows: the ratio rewards terms that are
            # distinctive to the current row.
            hits_this_row = match_info[offset]
            hits_all_rows = match_info[offset + 1]
            if hits_this_row > 0:
                score += weight * (float(hits_this_row) / hits_all_rows)

    return -score
1176
# Okapi BM25 ranking implementation (FTS4 only).
def bm25(raw_match_info, *args):
    """
    Okapi BM25 scoring over FTS4 matchinfo data. Optional positional args
    are per-column weights. Returns a negated score so ascending ORDER BY
    ranks best matches first.

    Usage:

        # Format string *must* be pcnalx
        # Second parameter to bm25 specifies the index of the column, on
        # the table being queried.
        bm25(matchinfo(document_tbl, 'pcnalx'), 1) AS rank
    """
    match_info = _parse_match_info(raw_match_info)
    # Standard BM25 free parameters.
    K = 1.2
    B = 0.75
    score = 0.0

    P_O, C_O, N_O, A_O = range(4)  # Offsets into the matchinfo buffer.
    term_count = match_info[P_O]  # n
    col_count = match_info[C_O]
    total_docs = match_info[N_O]  # N
    # 'a' (avg tokens) and 'l' (this doc's tokens) each span col_count slots,
    # so 'l' starts after 'a' and the 'x' triples start after 'l'.
    L_O = A_O + col_count
    X_O = L_O + col_count

    # Worked example of pcnalx for two columns and two phrases, 100 docs total.
    # {
    #   p  = 2
    #   c  = 2
    #   n  = 100
    #   a0 = 4   -- avg number of tokens for col0, e.g. title
    #   a1 = 40  -- avg number of tokens for col1, e.g. body
    #   l0 = 5   -- curr doc has 5 tokens in col0
    #   l1 = 30  -- curr doc has 30 tokens in col1
    #
    #   x000     -- hits this row for phrase0, col0
    #   x001     -- hits all rows for phrase0, col0
    #   x002     -- rows with phrase0 in col0 at least once
    #
    #   x010     -- hits this row for phrase0, col1
    #   x011     -- hits all rows for phrase0, col1
    #   x012     -- rows with phrase0 in col1 at least once
    #
    #   x100     -- hits this row for phrase1, col0
    #   x101     -- hits all rows for phrase1, col0
    #   x102     -- rows with phrase1 in col0 at least once
    #
    #   x110     -- hits this row for phrase1, col1
    #   x111     -- hits all rows for phrase1, col1
    #   x112     -- rows with phrase1 in col1 at least once
    # }

    weights = get_weights(col_count, args)

    for i in range(term_count):
        for j in range(col_count):
            weight = weights[j]
            if weight == 0:
                continue

            # Locate the (phrase i, column j) triple in the 'x' section.
            x = X_O + (3 * (j + i * col_count))
            term_frequency = float(match_info[x])  # f(qi, D)
            docs_with_term = float(match_info[x + 2])  # n(qi)

            # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) )
            idf = math.log(
                    (total_docs - docs_with_term + 0.5) /
                    (docs_with_term + 0.5))
            # Clamp non-positive idf (term in most docs) to a tiny positive
            # value so common terms do not flip the score's sign.
            if idf <= 0.0:
                idf = 1e-6

            doc_length = float(match_info[L_O + j])  # |D|
            avg_length = float(match_info[A_O + j]) or 1.  # avgdl
            ratio = doc_length / avg_length

            num = term_frequency * (K + 1.0)
            b_part = 1.0 - B + (B * ratio)
            denom = term_frequency + (K * b_part)

            pc_score = idf * (num / denom)
            score += (pc_score * weight)

    return -score
1257
1258
1259def _json_contains(src_json, obj_json):
1260    stack = []
1261    try:
1262        stack.append((json.loads(obj_json), json.loads(src_json)))
1263    except:
1264        # Invalid JSON!
1265        return False
1266
1267    while stack:
1268        obj, src = stack.pop()
1269        if isinstance(src, dict):
1270            if isinstance(obj, dict):
1271                for key in obj:
1272                    if key not in src:
1273                        return False
1274                    stack.append((obj[key], src[key]))
1275            elif isinstance(obj, list):
1276                for item in obj:
1277                    if item not in src:
1278                        return False
1279            elif obj not in src:
1280                return False
1281        elif isinstance(src, list):
1282            if isinstance(obj, dict):
1283                return False
1284            elif isinstance(obj, list):
1285                try:
1286                    for i in range(len(obj)):
1287                        stack.append((obj[i], src[i]))
1288                except IndexError:
1289                    return False
1290            elif obj not in src:
1291                return False
1292        elif obj != src:
1293            return False
1294    return True
1295