1import hashlib
2import zlib
3
4cimport cython
5from cpython cimport datetime
6from cpython.bytes cimport PyBytes_AsStringAndSize
7from cpython.bytes cimport PyBytes_Check
8from cpython.bytes cimport PyBytes_FromStringAndSize
9from cpython.bytes cimport PyBytes_AS_STRING
10from cpython.object cimport PyObject
11from cpython.ref cimport Py_INCREF, Py_DECREF
12from cpython.unicode cimport PyUnicode_AsUTF8String
13from cpython.unicode cimport PyUnicode_Check
14from cpython.unicode cimport PyUnicode_DecodeUTF8
15from cpython.version cimport PY_MAJOR_VERSION
16from libc.float cimport DBL_MAX
17from libc.math cimport ceil, log, sqrt
18from libc.math cimport pow as cpow
19#from libc.stdint cimport ssize_t
20from libc.stdint cimport uint8_t
21from libc.stdint cimport uint32_t
22from libc.stdlib cimport calloc, free, malloc, rand
23from libc.string cimport memcpy, memset, strlen
24
25from peewee import InterfaceError
26from peewee import Node
27from peewee import OperationalError
28from peewee import sqlite3 as pysqlite
29
30import traceback
31
32
# Cython-side copy of SQLite's `struct sqlite3_index_constraint`. This is a
# plain cdef struct (not an extern declaration), and SQLite hands us arrays of
# it via sqlite3_index_info.aConstraint, so the field order and types here
# must match the C definition in sqlite3.h exactly.
cdef struct sqlite3_index_constraint:
    int iColumn  # Column constrained, -1 for rowid.
    unsigned char op  # Constraint operator.
    unsigned char usable  # True if this constraint is usable.
    int iTermOffset  # Used internally - xBestIndex should ignore.
38
39
# Cython-side copy of SQLite's `struct sqlite3_index_orderby` (one ORDER BY
# term, read through sqlite3_index_info.aOrderBy). Layout must match sqlite3.h.
cdef struct sqlite3_index_orderby:
    int iColumn  # Column number.
    unsigned char desc  # True for DESC, false for ASC.
43
44
# Cython-side copy of SQLite's `struct sqlite3_index_constraint_usage`.
# xBestIndex (pwBestIndex) writes into SQLite's array of these via
# sqlite3_index_info.aConstraintUsage, so the layout must match sqlite3.h.
cdef struct sqlite3_index_constraint_usage:
    int argvIndex  # if > 0, constraint is part of argv to xFilter.
    unsigned char omit  # Non-zero: SQLite need not re-check this constraint.
48
49
cdef extern from "sqlite3.h" nogil:
    # NOTE(review): `sqlite3` is an opaque type in the public sqlite3.h; the
    # busyTimeout member is only a Cython-side declaration and is not used in
    # the visible code — confirm before relying on it.
    ctypedef struct sqlite3:
        int busyTimeout
    ctypedef struct sqlite3_backup
    ctypedef struct sqlite3_blob
    ctypedef struct sqlite3_context
    ctypedef struct sqlite3_value
    ctypedef long long sqlite3_int64
    ctypedef unsigned long long sqlite_uint64

    # Virtual tables.
    ctypedef struct sqlite3_module  # Forward reference.
    ctypedef struct sqlite3_vtab:
        const sqlite3_module *pModule
        int nRef
        char *zErrMsg
    ctypedef struct sqlite3_vtab_cursor:
        sqlite3_vtab *pVtab

    ctypedef struct sqlite3_index_info:
        int nConstraint
        sqlite3_index_constraint *aConstraint
        int nOrderBy
        sqlite3_index_orderby *aOrderBy
        sqlite3_index_constraint_usage *aConstraintUsage
        int idxNum
        char *idxStr
        int needToFreeIdxStr
        int orderByConsumed
        double estimatedCost
        sqlite3_int64 estimatedRows
        int idxFlags

    ctypedef struct sqlite3_module:
        int iVersion
        int (*xCreate)(sqlite3*, void *pAux, int argc, const char *const*argv,
                       sqlite3_vtab **ppVTab, char**)
        int (*xConnect)(sqlite3*, void *pAux, int argc, const char *const*argv,
                        sqlite3_vtab **ppVTab, char**)
        int (*xBestIndex)(sqlite3_vtab *pVTab, sqlite3_index_info*)
        int (*xDisconnect)(sqlite3_vtab *pVTab)
        int (*xDestroy)(sqlite3_vtab *pVTab)
        int (*xOpen)(sqlite3_vtab *pVTab, sqlite3_vtab_cursor **ppCursor)
        int (*xClose)(sqlite3_vtab_cursor*)
        int (*xFilter)(sqlite3_vtab_cursor*, int idxNum, const char *idxStr,
                       int argc, sqlite3_value **argv)
        int (*xNext)(sqlite3_vtab_cursor*)
        int (*xEof)(sqlite3_vtab_cursor*)
        int (*xColumn)(sqlite3_vtab_cursor*, sqlite3_context *, int)
        int (*xRowid)(sqlite3_vtab_cursor*, sqlite3_int64 *pRowid)
        # The final argument is a single out-pointer for the new rowid, as in
        # sqlite3.h (previously mis-declared as a pointer-to-pointer).
        int (*xUpdate)(sqlite3_vtab *pVTab, int, sqlite3_value **,
                       sqlite3_int64 *pRowid)
        int (*xBegin)(sqlite3_vtab *pVTab)
        int (*xSync)(sqlite3_vtab *pVTab)
        int (*xCommit)(sqlite3_vtab *pVTab)
        int (*xRollback)(sqlite3_vtab *pVTab)
        int (*xFindFunction)(sqlite3_vtab *pVTab, int nArg, const char *zName,
                             void (**pxFunc)(sqlite3_context *, int,
                                             sqlite3_value **),
                             void **ppArg)
        int (*xRename)(sqlite3_vtab *pVTab, const char *zNew)
        int (*xSavepoint)(sqlite3_vtab *pVTab, int)
        int (*xRelease)(sqlite3_vtab *pVTab, int)
        int (*xRollbackTo)(sqlite3_vtab *pVTab, int)

    cdef int sqlite3_declare_vtab(sqlite3 *db, const char *zSQL)
    cdef int sqlite3_create_module(sqlite3 *db, const char *zName,
                                   const sqlite3_module *p, void *pClientData)

    cdef const char sqlite3_version[]

    # Encoding.
    cdef int SQLITE_UTF8 = 1

    # Return values.
    cdef int SQLITE_OK = 0
    cdef int SQLITE_ERROR = 1
    cdef int SQLITE_INTERNAL = 2
    cdef int SQLITE_PERM = 3
    cdef int SQLITE_ABORT = 4
    cdef int SQLITE_BUSY = 5
    cdef int SQLITE_LOCKED = 6
    cdef int SQLITE_NOMEM = 7
    cdef int SQLITE_READONLY = 8
    cdef int SQLITE_INTERRUPT = 9
    cdef int SQLITE_DONE = 101

    # Function type.
    cdef int SQLITE_DETERMINISTIC = 0x800

    # Types of filtering operations.
    cdef int SQLITE_INDEX_CONSTRAINT_EQ = 2
    cdef int SQLITE_INDEX_CONSTRAINT_GT = 4
    cdef int SQLITE_INDEX_CONSTRAINT_LE = 8
    cdef int SQLITE_INDEX_CONSTRAINT_LT = 16
    cdef int SQLITE_INDEX_CONSTRAINT_GE = 32
    cdef int SQLITE_INDEX_CONSTRAINT_MATCH = 64

    # sqlite_value_type.
    cdef int SQLITE_INTEGER = 1
    cdef int SQLITE_FLOAT   = 2
    cdef int SQLITE3_TEXT   = 3
    cdef int SQLITE_TEXT    = 3
    cdef int SQLITE_BLOB    = 4
    cdef int SQLITE_NULL    = 5

    ctypedef void (*sqlite3_destructor_type)(void*)

    # Converting from Sqlite -> Python.
    cdef const void *sqlite3_value_blob(sqlite3_value*)
    cdef int sqlite3_value_bytes(sqlite3_value*)
    cdef double sqlite3_value_double(sqlite3_value*)
    cdef int sqlite3_value_int(sqlite3_value*)
    cdef sqlite3_int64 sqlite3_value_int64(sqlite3_value*)
    cdef const unsigned char *sqlite3_value_text(sqlite3_value*)
    cdef int sqlite3_value_type(sqlite3_value*)
    cdef int sqlite3_value_numeric_type(sqlite3_value*)

    # Converting from Python -> Sqlite.
    cdef void sqlite3_result_blob(sqlite3_context*, const void *, int,
                                  void(*)(void*))
    cdef void sqlite3_result_double(sqlite3_context*, double)
    cdef void sqlite3_result_error(sqlite3_context*, const char*, int)
    cdef void sqlite3_result_error_toobig(sqlite3_context*)
    cdef void sqlite3_result_error_nomem(sqlite3_context*)
    cdef void sqlite3_result_error_code(sqlite3_context*, int)
    cdef void sqlite3_result_int(sqlite3_context*, int)
    cdef void sqlite3_result_int64(sqlite3_context*, sqlite3_int64)
    cdef void sqlite3_result_null(sqlite3_context*)
    cdef void sqlite3_result_text(sqlite3_context*, const char*, int,
                                  void(*)(void*))
    cdef void sqlite3_result_value(sqlite3_context*, sqlite3_value*)

    # Memory management.
    cdef void* sqlite3_malloc(int)
    cdef void sqlite3_free(void *)

    cdef int sqlite3_changes(sqlite3 *db)
    cdef int sqlite3_get_autocommit(sqlite3 *db)
    cdef sqlite3_int64 sqlite3_last_insert_rowid(sqlite3 *db)

    cdef void *sqlite3_commit_hook(sqlite3 *, int(*)(void *), void *)
    cdef void *sqlite3_rollback_hook(sqlite3 *, void(*)(void *), void *)
    cdef void *sqlite3_update_hook(
        sqlite3 *,
        void(*)(void *, int, char *, char *, sqlite3_int64),
        void *)

    cdef int SQLITE_STATUS_MEMORY_USED = 0
    cdef int SQLITE_STATUS_PAGECACHE_USED = 1
    cdef int SQLITE_STATUS_PAGECACHE_OVERFLOW = 2
    cdef int SQLITE_STATUS_SCRATCH_USED = 3
    cdef int SQLITE_STATUS_SCRATCH_OVERFLOW = 4
    cdef int SQLITE_STATUS_MALLOC_SIZE = 5
    cdef int SQLITE_STATUS_PARSER_STACK = 6
    cdef int SQLITE_STATUS_PAGECACHE_SIZE = 7
    cdef int SQLITE_STATUS_SCRATCH_SIZE = 8
    cdef int SQLITE_STATUS_MALLOC_COUNT = 9
    cdef int sqlite3_status(int op, int *pCurrent, int *pHighwater, int resetFlag)

    cdef int SQLITE_DBSTATUS_LOOKASIDE_USED = 0
    cdef int SQLITE_DBSTATUS_CACHE_USED = 1
    cdef int SQLITE_DBSTATUS_SCHEMA_USED = 2
    cdef int SQLITE_DBSTATUS_STMT_USED = 3
    cdef int SQLITE_DBSTATUS_LOOKASIDE_HIT = 4
    cdef int SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE = 5
    cdef int SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL = 6
    cdef int SQLITE_DBSTATUS_CACHE_HIT = 7
    cdef int SQLITE_DBSTATUS_CACHE_MISS = 8
    cdef int SQLITE_DBSTATUS_CACHE_WRITE = 9
    cdef int SQLITE_DBSTATUS_DEFERRED_FKS = 10
    #cdef int SQLITE_DBSTATUS_CACHE_USED_SHARED = 11
    cdef int sqlite3_db_status(sqlite3 *, int op, int *pCur, int *pHigh, int reset)

    cdef int SQLITE_DELETE = 9
    cdef int SQLITE_INSERT = 18
    cdef int SQLITE_UPDATE = 23

    cdef int SQLITE_CONFIG_SINGLETHREAD = 1  # None
    cdef int SQLITE_CONFIG_MULTITHREAD = 2  # None
    cdef int SQLITE_CONFIG_SERIALIZED = 3  # None
    cdef int SQLITE_CONFIG_SCRATCH = 6  # void *, int sz, int N
    cdef int SQLITE_CONFIG_PAGECACHE = 7  # void *, int sz, int N
    cdef int SQLITE_CONFIG_HEAP = 8  # void *, int nByte, int min
    cdef int SQLITE_CONFIG_MEMSTATUS = 9  # boolean
    cdef int SQLITE_CONFIG_LOOKASIDE = 13  # int, int
    cdef int SQLITE_CONFIG_URI = 17  # int
    cdef int SQLITE_CONFIG_MMAP_SIZE = 22  # sqlite3_int64, sqlite3_int64
    cdef int SQLITE_CONFIG_STMTJRNL_SPILL = 26  # int nByte
    cdef int SQLITE_DBCONFIG_MAINDBNAME = 1000  # const char*
    cdef int SQLITE_DBCONFIG_LOOKASIDE = 1001  # void* int int
    cdef int SQLITE_DBCONFIG_ENABLE_FKEY = 1002  # int int*
    cdef int SQLITE_DBCONFIG_ENABLE_TRIGGER = 1003  # int int*
    cdef int SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER = 1004  # int int*
    cdef int SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION = 1005  # int int*
    cdef int SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE = 1006  # int int*
    cdef int SQLITE_DBCONFIG_ENABLE_QPSG = 1007  # int int*

    cdef int sqlite3_config(int, ...)
    cdef int sqlite3_db_config(sqlite3*, int op, ...)

    # Misc.
    cdef int sqlite3_busy_handler(sqlite3 *db, int(*)(void *, int), void *)
    cdef int sqlite3_sleep(int ms)

    # Backup.
    cdef sqlite3_backup *sqlite3_backup_init(
        sqlite3 *pDest,
        const char *zDestName,
        sqlite3 *pSource,
        const char *zSourceName)
    cdef int sqlite3_backup_step(sqlite3_backup *p, int nPage)
    cdef int sqlite3_backup_finish(sqlite3_backup *p)
    cdef int sqlite3_backup_remaining(sqlite3_backup *p)
    cdef int sqlite3_backup_pagecount(sqlite3_backup *p)

    # Error handling.
    cdef int sqlite3_errcode(sqlite3 *db)
    # sqlite3_errstr() returns a C string describing a result code (it was
    # previously mis-declared as returning int).
    cdef const char *sqlite3_errstr(int)
    cdef const char *sqlite3_errmsg(sqlite3 *db)

    # Incremental BLOB I/O.
    cdef int sqlite3_blob_open(
          sqlite3*,
          const char *zDb,
          const char *zTable,
          const char *zColumn,
          sqlite3_int64 iRow,
          int flags,
          sqlite3_blob **ppBlob)
    cdef int sqlite3_blob_reopen(sqlite3_blob *, sqlite3_int64)
    cdef int sqlite3_blob_close(sqlite3_blob *)
    cdef int sqlite3_blob_bytes(sqlite3_blob *)
    cdef int sqlite3_blob_read(sqlite3_blob *, void *Z, int N, int iOffset)
    cdef int sqlite3_blob_write(sqlite3_blob *, const void *z, int n,
                                int iOffset)
285
286
# Partial view of the Python sqlite3 module's internal connection object; only
# the members declared below are visible to Cython. TableFunction.register()
# casts a Python connection to this type so create_module() can read `db`.
# NOTE(review): depends on the private _pysqlite/connection.h matching the
# Python build this extension is compiled against — confirm on upgrades.
cdef extern from "_pysqlite/connection.h":
    ctypedef struct pysqlite_Connection:
        sqlite3* db  # Underlying SQLite database handle.
        double timeout
        int initialized
292
293
cdef sqlite_to_python(int argc, sqlite3_value **params):
    """Convert `argc` SQLite values into a list of Python objects.

    INTEGER -> int, FLOAT -> float, TEXT -> str (decoded as UTF-8),
    BLOB -> bytes, NULL (or an unrecognised type code) -> None.
    """
    cdef:
        int i
        int vtype
        int nbytes
        const char *data
        sqlite3_value *value
        list pyargs = []

    for i in range(argc):
        value = params[i]
        vtype = sqlite3_value_type(value)
        if vtype == SQLITE_INTEGER:
            # Use the 64-bit accessor; sqlite3_value_int() truncates values
            # that do not fit in a C int.
            pyval = sqlite3_value_int64(value)
        elif vtype == SQLITE_FLOAT:
            pyval = sqlite3_value_double(value)
        elif vtype == SQLITE_TEXT:
            # SQLite requires the text/blob pointer to be fetched before the
            # byte count. C leaves function-argument evaluation order
            # unspecified, so use separate statements rather than nesting both
            # calls in one argument list.
            data = <const char *>sqlite3_value_text(value)
            nbytes = sqlite3_value_bytes(value)
            pyval = PyUnicode_DecodeUTF8(data, <Py_ssize_t>nbytes, NULL)
        elif vtype == SQLITE_BLOB:
            data = <const char *>sqlite3_value_blob(value)
            nbytes = sqlite3_value_bytes(value)
            pyval = PyBytes_FromStringAndSize(data, <Py_ssize_t>nbytes)
        else:
            # SQLITE_NULL, and any type code not handled above.
            pyval = None

        pyargs.append(pyval)

    return pyargs
322
323
cdef python_to_sqlite(sqlite3_context *context, value):
    """Store the Python `value` as the result of `context`.

    None -> NULL, int -> INTEGER, float -> REAL, text -> UTF-8 TEXT, and
    bytes -> BLOB on Python 3 (TEXT on Python 2). Unsupported types and
    integers outside the signed 64-bit range set an error result instead.

    Returns SQLITE_OK, or SQLITE_ERROR after setting an error result.
    """
    cdef sqlite3_int64 ival

    if value is None:
        sqlite3_result_null(context)
    elif isinstance(value, (int, long)):
        try:
            ival = <sqlite3_int64>value
        except OverflowError:
            # SQLite integers are signed 64-bit. Report the problem as an SQL
            # error instead of letting OverflowError escape the C callback.
            sqlite3_result_error(
                context,
                encode('Integer out of range for SQLite: %s' % value),
                -1)
            return SQLITE_ERROR
        sqlite3_result_int64(context, ival)
    elif isinstance(value, float):
        sqlite3_result_double(context, <double>value)
    elif isinstance(value, unicode):
        bval = PyUnicode_AsUTF8String(value)
        # Destructor -1 is SQLITE_TRANSIENT: SQLite copies the buffer, so it
        # is safe for `bval` to be released after this call.
        sqlite3_result_text(
            context,
            <const char *>bval,
            len(bval),
            <sqlite3_destructor_type>-1)
    elif isinstance(value, bytes):
        if PY_MAJOR_VERSION > 2:
            sqlite3_result_blob(
                context,
                <void *>(<char *>value),
                len(value),
                <sqlite3_destructor_type>-1)
        else:
            # On Python 2 `bytes` is `str`, so it is stored as TEXT.
            sqlite3_result_text(
                context,
                <const char *>value,
                len(value),
                <sqlite3_destructor_type>-1)
    else:
        sqlite3_result_error(
            context,
            encode('Unsupported type %s' % type(value)),
            -1)
        return SQLITE_ERROR

    return SQLITE_OK
359
360
cdef int SQLITE_CONSTRAINT = 19  # Abort due to constraint violation.

# pwBestIndex only returns SQLITE_CONSTRAINT (to reject an unusable plan) when
# the linked SQLite is 3.26 or newer. Compare numeric (major, minor) parts: a
# bytewise comparison of the version prefix would wrongly rank e.g. b'3.9.'
# above b'3.26'.
USE_SQLITE_CONSTRAINT = [
    int(part) for part in
    sqlite3_version[:strlen(sqlite3_version)].split(b'.')[:2]] >= [3, 26]
364
# The peewee_vtab struct embeds the base sqlite3_vtab struct and adds a field
# that stores a reference to the Python implementation. `base` must remain the
# first member: SQLite only sees the sqlite3_vtab pointer, and the callbacks
# cast it back to peewee_vtab.
ctypedef struct peewee_vtab:
    sqlite3_vtab base
    void *table_func_cls  # Owned reference to the TableFunction subclass.
370
371
# Like peewee_vtab, the peewee_cursor embeds the base sqlite3_vtab_cursor and
# adds fields holding the current row index, the Python implementation
# instance, the current row's data, and a flag for whether the cursor has been
# exhausted. `base` must stay first so the sqlite3_vtab_cursor pointer that
# SQLite passes around can be cast back to peewee_cursor.
ctypedef struct peewee_cursor:
    sqlite3_vtab_cursor base
    long long idx  # Rows produced so far; also reported as the rowid.
    void *table_func  # Owned reference to the TableFunction instance.
    void *row_data  # Owned reference to the current row tuple, or NULL.
    bint stopped  # True once iterate() raised StopIteration.
382
383
# We define an xConnect function, but leave xCreate NULL so that the
# table-function can be called eponymously.
cdef int pwConnect(sqlite3 *db, void *pAux, int argc, const char *const*argv,
                   sqlite3_vtab **ppVtab, char **pzErr) with gil:
    """xConnect: declare the table schema and allocate a peewee_vtab.

    pAux is the TableFunction subclass passed to sqlite3_create_module().
    Returns an SQLite result code; *ppVtab is written only on success.
    """
    cdef:
        int rc
        bytes declaration
        object table_func_cls = <object>pAux
        peewee_vtab *pNew = <peewee_vtab *>0

    try:
        declaration = encode('CREATE TABLE x(%s);' %
                             table_func_cls.get_table_columns_declaration())
    except:
        # A bad columns/params definition must not propagate into SQLite's
        # C call stack; report a generic error instead.
        if table_func_cls.print_tracebacks:
            traceback.print_exc()
        return SQLITE_ERROR

    rc = sqlite3_declare_vtab(db, declaration)
    if rc != SQLITE_OK:
        return rc

    pNew = <peewee_vtab *>sqlite3_malloc(sizeof(pNew[0]))
    if pNew == NULL:
        return SQLITE_NOMEM
    memset(<char *>pNew, 0, sizeof(pNew[0]))

    # The vtab holds a strong reference to the class; pwDisconnect drops it.
    Py_INCREF(table_func_cls)
    pNew.table_func_cls = <void *>table_func_cls
    ppVtab[0] = &(pNew.base)
    return SQLITE_OK
406
407
cdef int pwDisconnect(sqlite3_vtab *pBase) with gil:
    """xDisconnect: undo pwConnect.

    Drops the class reference taken when the vtab was created, then releases
    the sqlite3_malloc()'d peewee_vtab itself.
    """
    cdef peewee_vtab *vtab = <peewee_vtab *>pBase
    cdef object owner_cls

    owner_cls = <object>vtab.table_func_cls
    Py_DECREF(owner_cls)
    sqlite3_free(vtab)
    return SQLITE_OK
416
417
# The xOpen method is used to initialize a cursor. In this method we
# instantiate the TableFunction class and zero out a new cursor for iteration.
cdef int pwOpen(sqlite3_vtab *pBase, sqlite3_vtab_cursor **ppCursor) with gil:
    """xOpen: create a cursor owning a fresh TableFunction instance.

    Returns SQLITE_ERROR if the constructor raises and SQLITE_NOMEM if the
    cursor cannot be allocated; *ppCursor is written only on success.
    """
    cdef:
        peewee_vtab *pVtab = <peewee_vtab *>pBase
        peewee_cursor *pCur = <peewee_cursor *>0
        object table_func_cls = <object>pVtab.table_func_cls

    # Instantiate before allocating, so a failing constructor needs no C-level
    # cleanup and *ppCursor never points at freed memory.
    try:
        table_func = table_func_cls()
    except:
        if table_func_cls.print_tracebacks:
            traceback.print_exc()
        return SQLITE_ERROR

    pCur = <peewee_cursor *>sqlite3_malloc(sizeof(pCur[0]))
    if pCur == NULL:
        return SQLITE_NOMEM
    memset(<char *>pCur, 0, sizeof(pCur[0]))

    # The cursor owns a strong reference to the instance; pwClose releases it.
    Py_INCREF(table_func)
    pCur.table_func = <void *>table_func
    pCur.idx = 0
    pCur.row_data = NULL
    pCur.stopped = False
    ppCursor[0] = &(pCur.base)
    return SQLITE_OK
442
443
cdef int pwClose(sqlite3_vtab_cursor *pBase) with gil:
    """xClose: release everything the cursor owns, then free it."""
    cdef:
        peewee_cursor *pCur = <peewee_cursor *>pBase
        object table_func = <object>pCur.table_func

    # A cursor closed before exhaustion (e.g. under LIMIT) still holds its
    # current row; release it here, or that tuple leaks.
    if pCur.row_data:
        Py_DECREF(<object>pCur.row_data)
        pCur.row_data = NULL

    Py_DECREF(table_func)
    sqlite3_free(pCur)
    return SQLITE_OK
451
452
# xNext: advance one row. Asks the TableFunction for row `idx`, stores it in
# the cursor's `row_data` field, and bumps the index.
cdef int pwNext(sqlite3_vtab_cursor *pBase) with gil:
    """Advance the cursor; StopIteration from iterate() marks it exhausted."""
    cdef:
        peewee_cursor *cursor = <peewee_cursor *>pBase
        object impl = <object>cursor.table_func
        tuple row

    # Drop the row produced by the previous step before fetching the next.
    if cursor.row_data != NULL:
        Py_DECREF(<tuple>cursor.row_data)
    cursor.row_data = NULL

    try:
        row = tuple(impl.iterate(cursor.idx))
    except StopIteration:
        cursor.stopped = True
        return SQLITE_OK
    except:
        if impl.print_tracebacks:
            traceback.print_exc()
        return SQLITE_ERROR

    # The cursor keeps its own reference to the row until the next step.
    Py_INCREF(row)
    cursor.row_data = <void *>row
    cursor.idx += 1
    cursor.stopped = False
    return SQLITE_OK
480
481
# Return the requested column from the current row.
cdef int pwColumn(sqlite3_vtab_cursor *pBase, sqlite3_context *ctx,
                  int iCol) with gil:
    """xColumn: set column `iCol` of the current row as the result of ctx.

    iCol == -1 reports the row index. Columns past the end of the row tuple
    (e.g. the HIDDEN parameter columns) are reported as NULL.
    """
    cdef:
        peewee_cursor *pCur = <peewee_cursor *>pBase
        tuple row_data

    if iCol == -1:
        sqlite3_result_int64(ctx, <sqlite3_int64>pCur.idx)
        return SQLITE_OK

    if not pCur.row_data:
        sqlite3_result_error(ctx, encode('no row data'), -1)
        return SQLITE_ERROR

    row_data = <tuple>pCur.row_data
    if iCol < 0 or iCol >= len(row_data):
        # Indexing would raise IndexError inside this C callback; the hidden
        # parameter columns that follow the declared columns have no value
        # here, so report NULL.
        sqlite3_result_null(ctx)
        return SQLITE_OK

    return python_to_sqlite(ctx, row_data[iCol])
501
502
cdef int pwRowid(sqlite3_vtab_cursor *pBase, sqlite3_int64 *pRowid):
    # xRowid: the cursor's row counter doubles as the rowid. Only C data is
    # touched here, so the function is not declared `with gil`.
    cdef long long current = (<peewee_cursor *>pBase).idx
    pRowid[0] = <sqlite3_int64>current
    return SQLITE_OK
508
509
# xEof: report 1 once the cursor has been consumed, 0 while rows remain.
cdef int pwEof(sqlite3_vtab_cursor *pBase):
    cdef peewee_cursor *cursor = <peewee_cursor *>pBase
    if cursor.stopped:
        return 1
    return 0
515
516
# The filter method is called at the start of each scan. This is where we get
# access to the parameters the function was called with, and call the
# TableFunction's `initialize()` method.
cdef int pwFilter(sqlite3_vtab_cursor *pBase, int idxNum,
                  const char *idxStr, int argc, sqlite3_value **argv) with gil:
    """xFilter: start a scan with the arguments chosen by pwBestIndex.

    idxStr is the comma-separated list of parameter names produced by
    pwBestIndex, in the same order as argv. Calls initialize(**params) and
    fetches the first row. Returns SQLITE_ERROR if the arguments are missing
    or the Python implementation raises.
    """
    cdef:
        peewee_cursor *pCur = <peewee_cursor *>pBase
        object table_func = <object>pCur.table_func
        dict query
        tuple row_data

    if not idxStr or (argc == 0 and len(table_func.params)):
        return SQLITE_ERROR
    elif len(idxStr):
        params = decode(idxStr).split(',')
    else:
        params = []

    # xFilter may run several times on one cursor (e.g. once per outer row of
    # a join). Release any row left over from the previous scan and restart
    # the counter, so rowids and iterate() indices begin again at zero.
    if pCur.row_data:
        Py_DECREF(<object>pCur.row_data)
        pCur.row_data = NULL
    pCur.idx = 0

    # Pair each parameter name with its converted argument value.
    query = dict(zip(params, sqlite_to_python(argc, argv)))

    try:
        table_func.initialize(**query)
    except:
        if table_func.print_tracebacks:
            traceback.print_exc()
        return SQLITE_ERROR

    pCur.stopped = False
    try:
        row_data = tuple(table_func.iterate(0))
    except StopIteration:
        pCur.stopped = True
    except:
        if table_func.print_tracebacks:
            traceback.print_exc()
        return SQLITE_ERROR
    else:
        Py_INCREF(row_data)
        pCur.row_data = <void *>row_data
        pCur.idx += 1
    return SQLITE_OK
568
569
# SQLite will (in some cases, repeatedly) call the xBestIndex method to try and
# find the best query plan.
cdef int pwBestIndex(sqlite3_vtab *pBase, sqlite3_index_info *pIdxInfo) \
        with gil:
    """xBestIndex: map equality constraints on parameters to xFilter args.

    Usable `=` constraints on the HIDDEN parameter columns become xFilter
    arguments (argvIndex order), and their names are passed to xFilter as a
    comma-separated idxStr. Plans missing parameters get a large cost, or are
    rejected with SQLITE_CONSTRAINT when USE_SQLITE_CONSTRAINT is set.
    """
    cdef:
        int i
        int nArg = 0
        peewee_vtab *pVtab = <peewee_vtab *>pBase
        object table_func_cls = <object>pVtab.table_func_cls
        sqlite3_index_constraint *pConstraint = <sqlite3_index_constraint *>0
        list columns = []
        bytes joinedCols
        char *idxStr
        int nParams = len(table_func_cls.params)
        int nCols = table_func_cls._ncols

    for i in range(pIdxInfo.nConstraint):
        pConstraint = pIdxInfo.aConstraint + i
        if not pConstraint.usable:
            continue
        if pConstraint.op != SQLITE_INDEX_CONSTRAINT_EQ:
            continue
        if pConstraint.iColumn < nCols:
            # Constraints on regular columns (or the rowid, iColumn == -1) are
            # not function arguments; leave them for SQLite to evaluate.
            # Offsetting them by nCols would give a negative index and
            # silently select the wrong parameter.
            continue

        columns.append(table_func_cls.params[pConstraint.iColumn - nCols])
        nArg += 1
        pIdxInfo.aConstraintUsage[i].argvIndex = nArg
        pIdxInfo.aConstraintUsage[i].omit = 1

    if nArg > 0 or nParams == 0:
        if nArg == nParams:
            # All parameters are present, this is ideal.
            pIdxInfo.estimatedCost = <double>1
            pIdxInfo.estimatedRows = 10
        else:
            # Penalize score based on number of missing params.
            pIdxInfo.estimatedCost = <double>10000000000000 * <double>(nParams - nArg)
            pIdxInfo.estimatedRows = 10 ** (nParams - nArg)

        # Pass the parameter names (in argvIndex order) to xFilter via idxStr.
        joinedCols = encode(','.join(columns))
        idxStr = <char *>sqlite3_malloc((len(joinedCols) + 1) * sizeof(char))
        if idxStr == NULL:
            return SQLITE_NOMEM
        memcpy(idxStr, <char *>joinedCols, len(joinedCols))
        idxStr[len(joinedCols)] = '\x00'
        pIdxInfo.idxStr = idxStr
        # idxStr was allocated with sqlite3_malloc(); have SQLite release it
        # with sqlite3_free() when done, otherwise every planning pass leaks.
        pIdxInfo.needToFreeIdxStr = 1
    elif USE_SQLITE_CONSTRAINT:
        return SQLITE_CONSTRAINT
    else:
        pIdxInfo.estimatedCost = DBL_MAX
        pIdxInfo.estimatedRows = 100000
    return SQLITE_OK
620
621
cdef class _TableFunctionImpl(object):
    """Owns the sqlite3_module struct used to register a TableFunction class.

    SQLite keeps raw pointers to `module` and to `table_function` (as pAux),
    so an instance is deliberately kept alive once registration succeeds.
    """
    cdef:
        sqlite3_module module
        object table_function

    def __cinit__(self, table_function):
        self.table_function = table_function

    cdef create_module(self, pysqlite_Connection* sqlite_conn):
        """Register the virtual-table module on sqlite_conn's database.

        Returns True when sqlite3_create_module() reports SQLITE_OK.
        """
        cdef:
            bytes name = encode(self.table_function.name)
            sqlite3 *db = sqlite_conn.db
            int rc

        # Populate the SQLite module struct members. xCreate stays NULL so the
        # table can be used eponymously (via xConnect only).
        self.module.iVersion = 0
        self.module.xCreate = NULL
        self.module.xConnect = pwConnect
        self.module.xBestIndex = pwBestIndex
        self.module.xDisconnect = pwDisconnect
        self.module.xDestroy = NULL
        self.module.xOpen = pwOpen
        self.module.xClose = pwClose
        self.module.xFilter = pwFilter
        self.module.xNext = pwNext
        self.module.xEof = pwEof
        self.module.xColumn = pwColumn
        self.module.xRowid = pwRowid
        self.module.xUpdate = NULL
        self.module.xBegin = NULL
        self.module.xSync = NULL
        self.module.xCommit = NULL
        self.module.xRollback = NULL
        self.module.xFindFunction = NULL
        self.module.xRename = NULL

        # Create the SQLite virtual table.
        rc = sqlite3_create_module(
            db,
            <const char *>name,
            &self.module,
            <void *>(self.table_function))

        if rc == SQLITE_OK:
            # Deliberately leak one reference: SQLite now holds pointers into
            # this object for the lifetime of the connection, so it must never
            # be deallocated. A failed registration leaves SQLite holding no
            # pointer, so no reference is taken in that case.
            Py_INCREF(self)

        return rc == SQLITE_OK
668
669
class TableFunction(object):
    """Base class for SQLite table-valued functions implemented in Python.

    Subclasses define `name`, `columns` and `params`, implement initialize()
    and iterate(), and are attached to a connection with register().
    """
    # Output columns: each either a name string or a (name, type) 2-tuple.
    columns = None
    # Argument names; declared to SQLite as HIDDEN columns after `columns`.
    params = None
    # Name the virtual-table module is registered under.
    name = None
    # When true, exceptions raised inside SQLite callbacks are printed.
    print_tracebacks = True
    # len(columns); filled in by register().
    _ncols = None

    @classmethod
    def register(cls, conn):
        """Register this table function on the sqlite3 connection `conn`.

        NOTE(review): create_module()'s success flag is not checked, so a
        failed registration is silent.
        """
        cdef _TableFunctionImpl module_impl
        module_impl = _TableFunctionImpl(cls)
        module_impl.create_module(<pysqlite_Connection *>conn)
        # The vtab callbacks offset hidden parameter columns by this count.
        cls._ncols = len(cls.columns)

    def initialize(self, **filters):
        """Start a new scan; keyword arguments are keyed by parameter name."""
        raise NotImplementedError

    def iterate(self, idx):
        """Return row `idx` as an iterable; raise StopIteration when done."""
        raise NotImplementedError

    @classmethod
    def get_table_columns_declaration(cls):
        """Build the column list for the virtual table's CREATE TABLE.

        Raises ValueError for a tuple column that is not (name, type).
        """
        cdef list parts = []

        for col in cls.columns:
            if not isinstance(col, tuple):
                parts.append(col)
                continue
            if len(col) != 2:
                raise ValueError('Column must be either a string or a '
                                 '2-tuple of name, type')
            parts.append('%s %s' % col)

        parts.extend(['%s HIDDEN' % p for p in cls.params])
        return ', '.join(parts)
706
707
# strptime() patterns tried, in this order, by validate_and_format_datetime().
cdef tuple SQLITE_DATETIME_FORMATS = (
    '%Y-%m-%d %H:%M:%S',
    '%Y-%m-%d %H:%M:%S.%f',
    '%Y-%m-%d',
    '%H:%M:%S',
    '%H:%M:%S.%f',
    '%H:%M',
)

# Lookup names accepted by validate_and_format_datetime(), each mapped to a
# format string that keeps only the fields down to that unit.
# NOTE(review): the values are presumably used with strftime() elsewhere.
cdef dict SQLITE_DATE_TRUNC_MAPPING = {
    'year': '%Y',
    'month': '%Y-%m',
    'day': '%Y-%m-%d',
    'hour': '%Y-%m-%d %H',
    'minute': '%Y-%m-%d %H:%M',
    'second': '%Y-%m-%d %H:%M:%S',
}
723
724
cdef tuple validate_and_format_datetime(lookup, date_str):
    """Parse `date_str` and normalise `lookup`.

    Returns (datetime, lowercased_lookup) for the first format in
    SQLITE_DATETIME_FORMATS that parses `date_str`. Returns None if either
    argument is empty, the lookup is unknown, or no format matches.
    """
    cdef datetime.datetime parsed

    if not (date_str and lookup):
        return None

    lookup = lookup.lower()
    if lookup not in SQLITE_DATE_TRUNC_MAPPING:
        return None

    for fmt in SQLITE_DATETIME_FORMATS:
        try:
            parsed = datetime.datetime.strptime(date_str, fmt)
        except ValueError:
            continue
        return (parsed, lookup)

    return None
743
744
cdef inline bytes encode(key):
    """Return `key` as UTF-8 bytes; None is passed through unchanged.

    Text is UTF-8 encoded, bytes are returned as-is, and any other object is
    first converted to text.
    """
    cdef bytes bkey
    if PyUnicode_Check(key):
        bkey = PyUnicode_AsUTF8String(key)
    elif PyBytes_Check(key):
        bkey = <bytes>key
    elif key is None:
        return None
    else:
        # unicode() rather than str(): on Python 2 str() yields bytes, which
        # PyUnicode_AsUTF8String rejects. Mirrors decode()'s use of unicode().
        bkey = PyUnicode_AsUTF8String(unicode(key))
    return bkey
756
757
cdef inline unicode decode(key):
    """Return `key` as text; None is passed through unchanged.

    Bytes are decoded as UTF-8, text is returned as-is, and any other object
    is converted with unicode().
    """
    if key is None:
        return None
    if PyUnicode_Check(key):
        return <unicode>key
    if PyBytes_Check(key):
        return key.decode('utf-8')
    return unicode(key)
769
770
cdef double *get_weights(int ncol, tuple raw_weights) except NULL:
    """Return a malloc()'d array of `ncol` column weights; caller must free().

    With no raw_weights every column gets 1.0. Otherwise column i gets
    float(raw_weights[i]), and columns beyond the supplied weights get 0.0.

    Raises MemoryError if allocation fails. If a weight cannot be converted
    to a float, its exception propagates after the buffer is freed.
    """
    cdef:
        int argc = len(raw_weights)
        int icol
        # Allocate at least one slot so that malloc(0) legitimately returning
        # NULL is not mistaken for an allocation failure.
        double *weights = <double *>malloc(
            sizeof(double) * (ncol if ncol > 0 else 1))

    if weights == NULL:
        raise MemoryError()

    try:
        for icol in range(ncol):
            if argc == 0:
                weights[icol] = 1.0
            elif icol < argc:
                weights[icol] = <double>raw_weights[icol]
            else:
                weights[icol] = 0.0
    except BaseException:
        # Don't leak the buffer when a weight fails float conversion.
        free(weights)
        raise
    return weights
785
786
def peewee_rank(py_match_info, *raw_weights):
    """Simple relevance score from a matchinfo blob in 'pcx' layout.

    For each phrase and column, adds weight * (hits in this row / hits in all
    rows). Optional positional weights apply per column (see get_weights).
    The score is returned negated.

    Raises ValueError if the blob is shorter than its header requires.
    """
    cdef:
        unsigned int *match_info
        unsigned int *phrase_info
        bytes _match_info_buf = bytes(py_match_info)
        char *match_info_buf = _match_info_buf
        Py_ssize_t nwords = (len(_match_info_buf) //
                             <Py_ssize_t>sizeof(unsigned int))
        int nphrase, ncol, icol, iphrase, hits, global_hits
        int P_O = 0, C_O = 1, X_O = 2
        double score = 0.0, weight
        double *weights

    # The phrase and column counts must be present before they are read.
    if nwords < 2:
        raise ValueError('matchinfo buffer is too short.')

    match_info = <unsigned int *>match_info_buf
    nphrase = match_info[P_O]
    ncol = match_info[C_O]

    # The blob is caller-supplied: make sure it really contains the
    # 3 * nphrase * ncol counters its header claims before indexing into it.
    # Python ints are used so the size computation cannot overflow.
    if (nphrase < 0 or ncol < 0 or
            nwords < X_O + 3 * <object>nphrase * ncol):
        raise ValueError('matchinfo buffer is too short.')

    weights = get_weights(ncol, raw_weights)

    # matchinfo X value corresponds to, for each phrase in the search query, a
    # list of 3 values for each column in the search table.
    # So if we have a two-phrase search query and three columns of data, the
    # following would be the layout:
    # p0 : c0=[0, 1, 2],   c1=[3, 4, 5],    c2=[6, 7, 8]
    # p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17]
    for iphrase in range(nphrase):
        phrase_info = &match_info[X_O + iphrase * ncol * 3]
        for icol in range(ncol):
            weight = weights[icol]
            if weight == 0:
                continue

            # The idea is that we count the number of times the phrase appears
            # in this column of the current row, compared to how many times it
            # appears in this column across all rows. The ratio of these values
            # provides a rough way to score based on "high value" terms.
            hits = phrase_info[3 * icol]
            global_hits = phrase_info[3 * icol + 1]
            if hits > 0:
                score += weight * (<double>hits / <double>global_hits)

    free(weights)
    return -1 * score
827
828
def peewee_lucene(py_match_info, *raw_weights):
    """Lucene-style score from a matchinfo blob in 'pcnalx' layout.

    Usage: peewee_lucene(matchinfo(table, 'pcnalx'), 1)

    For each phrase and column, adds
    weight * log(N / (n + 1)) * sqrt(f) / sqrt(|D|), where
    N = total rows, n = rows containing the phrase in that column,
    f = phrase hits in this row, and |D| = column length in tokens.
    Optional positional weights apply per column (see get_weights). The
    score is returned negated.

    Raises ValueError if the blob is shorter than its header requires.
    """
    cdef:
        unsigned int *match_info
        bytes _match_info_buf = bytes(py_match_info)
        char *match_info_buf = _match_info_buf
        Py_ssize_t nwords = (len(_match_info_buf) //
                             <Py_ssize_t>sizeof(unsigned int))
        int nphrase, ncol
        double total_docs, term_frequency
        double doc_length, docs_with_term
        double idf, weight, tf, field_norm
        double *weights
        int P_O = 0, C_O = 1, N_O = 2, L_O, X_O
        int iphrase, icol, x
        double score = 0.0

    # The p, c and n header values must be present before they are read.
    if nwords < 3:
        raise ValueError('matchinfo buffer is too short.')

    match_info = <unsigned int *>match_info_buf
    nphrase = match_info[P_O]
    ncol = match_info[C_O]
    total_docs = match_info[N_O]

    # After the header come ncol 'a' values, ncol 'l' values and
    # 3 * nphrase * ncol 'x' values. Make sure the caller-supplied blob is
    # that long; Python ints keep the size computation from overflowing.
    if (nphrase < 0 or ncol < 0 or
            nwords < 3 + 2 * <object>ncol + 3 * <object>nphrase * ncol):
        raise ValueError('matchinfo buffer is too short.')

    L_O = 3 + ncol
    X_O = L_O + ncol
    weights = get_weights(ncol, raw_weights)

    for iphrase in range(nphrase):
        for icol in range(ncol):
            weight = weights[icol]
            if weight == 0:
                continue
            doc_length = match_info[L_O + icol]
            if doc_length == 0:
                # Empty column in this row: nothing can match here, and the
                # length normalisation below would divide by zero.
                continue
            x = X_O + (3 * (icol + iphrase * ncol))
            term_frequency = match_info[x]  # f(qi)
            docs_with_term = match_info[x + 2] or 1.  # n(qi)
            idf = log(total_docs / (docs_with_term + 1.))
            tf = sqrt(term_frequency)
            field_norm = 1.0 / sqrt(doc_length)
            # Apply the column weight, as peewee_rank and peewee_bm25 do.
            score += weight * (idf * tf * field_norm)

    free(weights)
    return -1 * score
869
870
def peewee_bm25(py_match_info, *raw_weights):
    """Okapi BM25 score (k1=1.2, b=0.75) from a 'pcnalx' matchinfo blob.

    Usage: peewee_bm25(matchinfo(table, 'pcnalx'), 1)
    The extra positional arguments are per-column weights (see
    get_weights). The score is returned negated.

    Raises ValueError if the blob is shorter than its header requires.
    """
    cdef:
        unsigned int *match_info
        bytes _match_info_buf = bytes(py_match_info)
        char *match_info_buf = _match_info_buf
        Py_ssize_t nwords = (len(_match_info_buf) //
                             <Py_ssize_t>sizeof(unsigned int))
        int nphrase, ncol
        double B = 0.75, K = 1.2
        double total_docs, term_frequency
        double doc_length, docs_with_term, avg_length
        double idf, weight, ratio, num, b_part, denom, pc_score
        double *weights
        int P_O = 0, C_O = 1, N_O = 2, A_O = 3, L_O, X_O
        int iphrase, icol, x
        double score = 0.0

    # The p, c and n header values must be present before they are read.
    if nwords < 3:
        raise ValueError('matchinfo buffer is too short.')

    match_info = <unsigned int *>match_info_buf
    # PCNALX = matchinfo format.
    # P = 1 = phrase count within query.
    # C = 1 = searchable columns in table.
    # N = 1 = total rows in table.
    # A = c = for each column, avg number of tokens
    # L = c = for each column, length of current row (in tokens)
    # X = 3 * c * p = for each phrase and table column,
    # * phrase count within column for current row.
    # * phrase count within column for all rows.
    # * total rows for which column contains phrase.
    nphrase = match_info[P_O]  # n
    ncol = match_info[C_O]
    total_docs = match_info[N_O]  # N

    # The blob is caller-supplied: verify it holds the A, L and X sections
    # its header implies. Python ints keep the size computation overflow-free.
    if (nphrase < 0 or ncol < 0 or
            nwords < A_O + 2 * <object>ncol + 3 * <object>nphrase * ncol):
        raise ValueError('matchinfo buffer is too short.')

    L_O = A_O + ncol
    X_O = L_O + ncol
    weights = get_weights(ncol, raw_weights)

    for iphrase in range(nphrase):
        for icol in range(ncol):
            weight = weights[icol]
            if weight == 0:
                continue

            x = X_O + (3 * (icol + iphrase * ncol))
            term_frequency = match_info[x]  # f(qi, D)
            docs_with_term = match_info[x + 2]  # n(qi)

            # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) )
            idf = log(
                    (total_docs - docs_with_term + 0.5) /
                    (docs_with_term + 0.5))
            # Common terms would otherwise get a non-positive idf.
            if idf <= 0.0:
                idf = 1e-6

            doc_length = match_info[L_O + icol]  # |D|
            avg_length = match_info[A_O + icol]  # avgdl
            if avg_length == 0:
                avg_length = 1
            ratio = doc_length / avg_length

            num = term_frequency * (K + 1)
            b_part = 1 - B + (B * ratio)
            denom = term_frequency + (K * b_part)

            pc_score = idf * (num / denom)
            score += (pc_score * weight)

    free(weights)
    return -1 * score
940
941
def peewee_bm25f(py_match_info, *raw_weights):
    # BM25 variant (registered as fts_bm25f) in which the document length
    # and average length are summed across all columns, and each per-column
    # term score is incremented by 1 before weighting.
    #
    # Usage: peewee_bm25f(matchinfo(table, 'pcnalx'), w0, w1, ...)
    # The first argument must be the blob returned by matchinfo() using the
    # format string 'pcnalx'. Any remaining arguments are per-column weights
    # (converted by get_weights(); a weight of 0 skips that column). The
    # BM25 parameters are fixed at K = 1.2 and B = 0.75.
    #
    # Returns the negated score, so ORDER BY ascending lists the best
    # matches first. Raises ValueError if the matchinfo buffer is shorter
    # than the counts it declares.
    cdef:
        unsigned int *match_info
        bytes _match_info_buf = bytes(py_match_info)
        char *match_info_buf = _match_info_buf
        Py_ssize_t n_words = (len(_match_info_buf) //
                              <Py_ssize_t>sizeof(unsigned int))
        int nphrase, ncol
        double B = 0.75, K = 1.2, epsilon
        double total_docs, term_frequency, docs_with_term
        double doc_length = 0.0, avg_length = 0.0
        double idf, weight, ratio, num, b_part, denom, pc_score
        double *weights
        int P_O = 0, C_O = 1, N_O = 2, A_O = 3, L_O, X_O
        int iphrase, icol, x
        double score = 0.0

    # The p, c and n header words must be present before anything is read.
    if n_words < 3:
        raise ValueError('matchinfo buffer is too short.')

    match_info = <unsigned int *>match_info_buf
    nphrase = match_info[P_O]  # n
    ncol = match_info[C_O]
    total_docs = match_info[N_O]  # N

    # The counts come from the (caller-supplied) buffer itself, so check that
    # it really holds 3 + 2*ncol + 3*ncol*nphrase words before indexing.
    # Dividing instead of multiplying keeps the check free of overflow.
    if nphrase < 0 or ncol < 0 or ncol > (n_words - 3) // 2:
        raise ValueError('matchinfo buffer is too short.')

    L_O = A_O + ncol
    X_O = L_O + ncol
    if ncol > 0 and (n_words - X_O) // (3 * <Py_ssize_t>ncol) < nphrase:
        raise ValueError('matchinfo buffer is too short.')

    for icol in range(ncol):
        avg_length += match_info[A_O + icol]
        doc_length += match_info[L_O + icol]

    if avg_length == 0:
        avg_length = 1
    ratio = doc_length / avg_length

    # Floor substituted for non-positive IDF values below. Computed after the
    # avg_length guard, and with total_docs clamped to at least 1, so that a
    # zero in either cannot raise ZeroDivisionError.
    epsilon = 1.0 / ((total_docs if total_docs > 0 else 1.0) * avg_length)

    # Length normalization depends only on the row, not on phrase/column.
    b_part = 1 - B + (B * ratio)

    weights = get_weights(ncol, raw_weights)

    for iphrase in range(nphrase):
        for icol in range(ncol):
            weight = weights[icol]
            if weight == 0:
                continue

            x = X_O + (3 * (icol + iphrase * ncol))
            term_frequency = match_info[x]  # f(qi, D)
            docs_with_term = match_info[x + 2]  # n(qi)

            # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) )
            idf = log(
                (total_docs - docs_with_term + 0.5) /
                (docs_with_term + 0.5))
            idf = epsilon if idf <= 0 else idf

            num = term_frequency * (K + 1)
            denom = term_frequency + (K * b_part)

            pc_score = idf * ((num / denom) + 1.)
            score += (pc_score * weight)

    free(weights)
    return -1 * score
1003
1004
cdef uint32_t murmurhash2(const unsigned char *key, ssize_t nlen,
                          uint32_t seed):
    # 32-bit MurmurHash2 of the `nlen` bytes at `key`, mixed with `seed`.
    #
    # 4-byte blocks are loaded in the host's native byte order, so results
    # are only comparable between hosts of the same endianness.
    cdef:
        uint32_t m = 0x5bd1e995
        int r = 24
        const unsigned char *data = key
        uint32_t h = seed ^ nlen
        uint32_t k

    while nlen >= 4:
        # memcpy rather than `(<uint32_t *>data)[0]`: `data` has no alignment
        # guarantee and the pointer cast violates strict aliasing, both
        # undefined behavior in C. memcpy yields the same native-order value.
        memcpy(<void *>&k, <const void *>data, 4)

        k *= m
        k = k ^ (k >> r)
        k *= m

        h *= m
        h = h ^ k

        data += 4
        nlen -= 4

    # Mix in the 1-3 trailing bytes (the C reference uses switch fallthrough).
    if nlen == 3:
        h = h ^ (data[2] << 16)
    if nlen >= 2:
        h = h ^ (data[1] << 8)
    if nlen >= 1:
        h = h ^ (data[0])
        h *= m

    # Final avalanche.
    h = h ^ (h >> 13)
    h *= m
    h = h ^ (h >> 15)
    return h
1039
1040
def peewee_murmurhash(key, seed=None):
    # SQL function: 32-bit MurmurHash2 of `key` (after encode()).
    # NULL -> NULL, any falsy key -> 0; a missing/falsy seed means seed 0.
    cdef:
        bytes bkey
        int nseed

    if key is None:
        return None

    bkey = encode(key)
    nseed = seed or 0

    if not key:
        return 0
    return murmurhash2(<unsigned char *>bkey, len(bkey), nseed)
1052
1053
def make_hash(hash_impl):
    """
    Build a SQL-callable hex-digest function from a hashlib constructor.

    The returned function passes each of its arguments, in order, through
    encode() into a single ``hash_impl()`` object and returns its hex digest.
    """
    def inner(*items):
        digest = hash_impl()
        for chunk in map(encode, items):
            digest.update(chunk)
        return digest.hexdigest()
    return inner
1061
1062
# SQL-callable hex-digest functions built with make_hash().
peewee_md5, peewee_sha1, peewee_sha256 = (
    make_hash(hashlib.md5),
    make_hash(hashlib.sha1),
    make_hash(hashlib.sha256))
1066
1067
def _register_functions(database, pairs):
    """Register each ``(func, name)`` pair on ``database`` as a SQL function."""
    for fn, sql_name in pairs:
        database.register_function(fn, sql_name)
1071
1072
def _checksum_function(checksum_impl):
    # Adapt a zlib checksum (adler32/crc32) for use as a SQL function.
    # zlib only accepts bytes-like input, but SQLite TEXT values arrive as
    # str on Python 3, so values are run through encode() first (as
    # make_hash() does for the hashlib digests). SQL NULL maps to NULL.
    def inner(value):
        if value is None:
            return None
        return checksum_impl(encode(value))
    return inner


def register_hash_functions(database):
    """
    Register murmurhash, md5, sha1, sha256, adler32 and crc32 as SQL
    functions on ``database``.
    """
    _register_functions(database, (
        (peewee_murmurhash, 'murmurhash'),
        (peewee_md5, 'md5'),
        (peewee_sha1, 'sha1'),
        (peewee_sha256, 'sha256'),
        (_checksum_function(zlib.adler32), 'adler32'),
        (_checksum_function(zlib.crc32), 'crc32')))
1081
1082
def register_rank_functions(database):
    """
    Register the full-text ranking helpers (fts_bm25, fts_bm25f, fts_lucene,
    fts_rank) as SQL functions on ``database``.
    """
    rank_functions = (
        (peewee_bm25, 'fts_bm25'),
        (peewee_bm25f, 'fts_bm25f'),
        (peewee_lucene, 'fts_lucene'),
        (peewee_rank, 'fts_rank'),
    )
    _register_functions(database, rank_functions)
1089
1090
# Bloom filter state: a bit array and its length. Bit positions range over
# [0, size * 8) (see bf_bitindex).
ctypedef struct bf_t:
    void *bits  # Bit array, `size` bytes long.
    size_t size  # Length of `bits` in bytes (not bits).
1094
# Hash seeds for the Bloom filter. bf_add()/bf_contains() set/test one bit
# per seed, so this table fixes the number of hash functions (10). The bit
# positions in any serialized filter (BloomFilter.to_buffer()) depend on
# these exact values, so changing them invalidates stored filters.
cdef int seeds[10]
seeds[:] = [0, 1337, 37, 0xabcd, 0xdead, 0xface, 97, 0xed11, 0xcad9, 0x827b]
1097
1098
cdef bf_t *bf_create(size_t size):
    # Allocate a zeroed bf_t with a zeroed `size`-byte bit array.
    # Returns NULL if an allocation fails. A zero-byte bit array may come
    # back from calloc as NULL; that is not treated as a failure.
    cdef bf_t *bf = <bf_t *>calloc(1, sizeof(bf_t))
    if bf == NULL:
        return NULL

    bf.size = size
    bf.bits = calloc(1, size)
    if bf.bits == NULL and size > 0:
        free(bf)
        return NULL
    return bf
1104
@cython.cdivision(True)
cdef uint32_t bf_bitindex(bf_t *bf, unsigned char *key, size_t klen, int seed):
    # Map `key` (klen bytes) to a bit position in [0, bf.size * 8) using the
    # seeded MurmurHash2. cdivision means the modulo is not zero-checked, so
    # callers must not pass a filter with bf.size == 0.
    return murmurhash2(key, klen, seed) % (bf.size * 8)
1110
@cython.cdivision(True)
cdef bf_add(bf_t *bf, unsigned char *key):
    # Set the bit selected by each seed for the NUL-terminated `key`. The key
    # length comes from strlen(), so bytes after an embedded NUL are ignored.
    # Raises ValueError for a zero-size filter.
    cdef:
        uint8_t *bits = <uint8_t *>(bf.bits)
        uint32_t h
        int pos, seed
        size_t keylen = strlen(<const char *>key)

    if bf.size == 0:
        # bf_bitindex() reduces modulo size * 8 with C division semantics;
        # a zero modulus is undefined behavior (typically a SIGFPE crash).
        raise ValueError('Cannot add keys to a zero-size bloom filter.')

    for seed in seeds:
        h = bf_bitindex(bf, key, keylen, seed)
        pos = h / 8
        bits[pos] = bits[pos] | (1 << (h % 8))
1123
@cython.cdivision(True)
cdef int bf_contains(bf_t *bf, unsigned char *key):
    # Return 1 if every bit selected by the seeds is set for the
    # NUL-terminated `key` (possible member), 0 otherwise (definitely absent).
    # A zero-size filter contains nothing.
    cdef:
        uint8_t *bits = <uint8_t *>(bf.bits)
        uint32_t h
        int pos, seed
        size_t keylen = strlen(<const char *>key)

    if bf.size == 0:
        # Avoid the modulo-by-zero in bf_bitindex(); reachable from SQL via
        # bloomfilter_contains(key, X'').
        return 0

    for seed in seeds:
        h = bf_bitindex(bf, key, keylen, seed)
        pos = h / 8
        if not (bits[pos] & (1 << (h % 8))):
            return 0
    return 1
1138
cdef bf_free(bf_t *bf):
    # Release a filter allocated by bf_create(): the bit array, then the
    # struct. Passing NULL is a no-op.
    if bf == NULL:
        return
    free(bf.bits)
    free(bf)
1142
1143
cdef class BloomFilter(object):
    """
    Fixed-size Bloom filter over a C bit array.

    ``size`` is the length of the bit array in bytes (``size * 8`` bits).
    Each key sets/tests one bit per entry of the module-level ``seeds``
    table, so a key that was added always tests as present, while absent
    keys may occasionally test as present (false positives).
    """
    cdef:
        bf_t *bf

    def __init__(self, size=1024 * 32):
        cdef bf_t *new_bf = bf_create(<size_t>size)

        # Turn allocation failure into MemoryError instead of a later NULL
        # dereference. A zero-byte bit array may legitimately be NULL.
        if new_bf == NULL or (new_bf.bits == NULL and new_bf.size > 0):
            if new_bf != NULL:
                free(new_bf)
            raise MemoryError('Unable to allocate bloom filter.')

        # Calling __init__ again must not leak the previous filter.
        if self.bf:
            bf_free(self.bf)
        self.bf = new_bf

    def __dealloc__(self):
        if self.bf:
            bf_free(self.bf)

    def __len__(self):
        # Size of the bit array in bytes.
        return self.bf.size

    def add(self, *keys):
        # Keys are encode()d; hashing stops at the first NUL byte.
        cdef bytes bkey

        for key in keys:
            bkey = encode(key)
            bf_add(self.bf, <unsigned char *>bkey)

    def __contains__(self, key):
        cdef bytes bkey = encode(key)
        return bf_contains(self.bf, <unsigned char *>bkey)

    def to_buffer(self):
        # Copy exactly `size` bytes with PyBytes_FromStringAndSize so that
        # embedded NUL bytes in the bit array are preserved.
        cdef bytes buf = PyBytes_FromStringAndSize(<char *>(self.bf.bits),
                                                   self.bf.size)
        return buf

    @classmethod
    def from_buffer(cls, data):
        # Rebuild a filter from bytes such as to_buffer() returns; the
        # filter size is len(data).
        cdef:
            char *buf
            Py_ssize_t buflen
            BloomFilter bloom

        PyBytes_AsStringAndSize(data, &buf, &buflen)

        bloom = BloomFilter(buflen)
        if buflen:
            # A zero-size filter may have a NULL bit array; skip the copy.
            memcpy(bloom.bf.bits, <void *>buf, buflen)
        return bloom

    @classmethod
    def calculate_size(cls, double n, double p):
        # Optimal bit count m = -n * ln(p) / (ln 2)^2 for n items at
        # false-positive rate p, since log(1 / 2**ln2) == -(ln 2)^2.
        # Note: the result is in bits, whereas BloomFilter(size) takes bytes.
        cdef double m = ceil((n * log(p)) / log(1.0 / (pow(2.0, log(2.0)))))
        return m
1194
1195
cdef class BloomFilterAggregate(object):
    """
    SQLite aggregate that adds every stepped value to a BloomFilter and
    returns the filter's bit array as a BLOB (NULL when no rows were seen).
    """
    cdef:
        BloomFilter bf

    def __init__(self):
        self.bf = None

    def step(self, value, size=None):
        # The filter is sized from the first row only (falsy size -> 1024
        # bytes); `size` on later rows is ignored. Test against None
        # explicitly: plain truthiness would call BloomFilter.__len__.
        if self.bf is None:
            self.bf = BloomFilter(size or 1024)

        self.bf.add(value)

    def finalize(self):
        if self.bf is None:
            return None

        return pysqlite.Binary(self.bf.to_buffer())
1215
1216
def peewee_bloomfilter_contains(key, data):
    # SQL function: 1 if `key` may be present in the serialized bit array
    # `data` (e.g. what BloomFilter.to_buffer() returns), 0 otherwise.
    # The filter is read in place; nothing is allocated or freed.
    cdef:
        bf_t bf
        bytes bkey
        bytes bdata = bytes(data)
        unsigned char *cdata = <unsigned char *>bdata

    # Size from the converted buffer, which is what bf.bits points to.
    bf.size = len(bdata)
    if bf.size == 0:
        # An empty filter contains nothing; this also avoids the
        # modulo-by-zero in bf_bitindex().
        return 0

    bf.bits = <void *>cdata
    bkey = encode(key)

    return bf_contains(&bf, <unsigned char *>bkey)
1229
1230
def peewee_bloomfilter_calculate_size(n_items, error_p):
    """SQL wrapper around BloomFilter.calculate_size(n_items, error_p)."""
    optimal_size = BloomFilter.calculate_size(n_items, error_p)
    return optimal_size
1233
1234
def register_bloomfilter(database):
    """
    Register the ``bloomfilter`` aggregate and the ``bloomfilter_contains`` /
    ``bloomfilter_calculate_size`` scalar functions on ``database``.
    """
    database.register_aggregate(BloomFilterAggregate, 'bloomfilter')
    scalar_functions = (
        (peewee_bloomfilter_contains, 'bloomfilter_contains'),
        (peewee_bloomfilter_calculate_size, 'bloomfilter_calculate_size'),
    )
    for fn, sql_name in scalar_functions:
        database.register_function(fn, sql_name)
1241
1242
cdef inline int _check_connection(pysqlite_Connection *conn) except -1:
    """
    Guard for code that is about to use ``conn.db``.

    Raises InterfaceError when the connection has no sqlite3 handle (for
    example after it was closed); otherwise returns 1.
    """
    if conn.db == NULL:
        raise InterfaceError('Cannot operate on closed database.')
    return 1
1251
1252
class ZeroBlob(Node):
    """
    SQL node rendering ``zeroblob(length)``, SQLite's function returning a
    BLOB of ``length`` zero bytes.
    """
    def __init__(self, length):
        # The value is interpolated straight into SQL by __sql__, so it must
        # be a real non-negative integer. bool is an int subclass but is not
        # a meaningful length, so it is rejected too.
        if (isinstance(length, bool) or not isinstance(length, int) or
                length < 0):
            raise ValueError('Length must be a non-negative integer.')
        self.length = length

    def __sql__(self, ctx):
        # %d formats the numeric value, never a subclass's custom __str__.
        return ctx.literal('zeroblob(%d)' % self.length)
1261
1262
cdef class Blob(object)  # Forward declaration.


cdef inline int _check_blob_closed(Blob blob) except -1:
    # Raise InterfaceError if the blob handle has been closed (pBlob is
    # NULL); otherwise return 1.
    if blob.pBlob == NULL:
        raise InterfaceError('Cannot operate on closed blob.')
    return 1
1270
1271
cdef class Blob(object):
    """
    File-like wrapper around an SQLite incremental BLOB handle
    (sqlite3_blob_open) with a current offset used by read() and write().

    Only a raw pointer to the connection is kept; presumably the caller's
    Database keeps the connection alive for the Blob's lifetime -- verify.
    """
    cdef:
        int offset
        pysqlite_Connection *conn
        sqlite3_blob *pBlob

    def __init__(self, database, table, column, rowid,
                 read_only=False):
        cdef:
            bytes btable = encode(table)
            bytes bcolumn = encode(column)
            int flags = 0 if read_only else 1
            int rc
            sqlite3_blob *blob = NULL

        self.conn = <pysqlite_Connection *>(database._state.conn)
        _check_connection(self.conn)

        # Re-initializing an open Blob must not leak the previous handle.
        self._close()

        rc = sqlite3_blob_open(
            self.conn.db,
            'main',
            <char *>btable,
            <char *>bcolumn,
            <long long>rowid,
            flags,
            &blob)
        if rc != SQLITE_OK:
            raise OperationalError('Unable to open blob.')
        if not blob:
            raise MemoryError('Unable to allocate blob.')

        self.pBlob = blob
        self.offset = 0

    cdef _close(self):
        if self.pBlob:
            sqlite3_blob_close(self.pBlob)
        self.pBlob = <sqlite3_blob *>0

    def __dealloc__(self):
        self._close()

    def __len__(self):
        _check_blob_closed(self)
        return sqlite3_blob_bytes(self.pBlob)

    def read(self, n=None):
        # Read up to `n` bytes (all remaining bytes if n is None or negative)
        # from the current offset, advancing it. Returns b'' at the end.
        cdef:
            bytes pybuf
            int length = -1
            int size
            char *buf

        if n is not None:
            length = n

        _check_blob_closed(self)
        size = sqlite3_blob_bytes(self.pBlob)
        if self.offset == size or length == 0:
            return b''

        # Clamp to the remaining bytes. Comparing against `size - offset`
        # avoids the int overflow `offset + length` could hit for large n.
        if length < 0 or length > size - self.offset:
            length = size - self.offset

        pybuf = PyBytes_FromStringAndSize(NULL, length)
        buf = PyBytes_AS_STRING(pybuf)
        if sqlite3_blob_read(self.pBlob, buf, length, self.offset):
            self._close()
            raise OperationalError('Error reading from blob.')

        self.offset += length
        return pybuf

    def seek(self, offset, frame_of_reference=0):
        # frame_of_reference: 0 = from start, 1 = from current offset,
        # 2 = from end. The resulting offset must lie within [0, size].
        cdef int size
        _check_blob_closed(self)
        size = sqlite3_blob_bytes(self.pBlob)
        if frame_of_reference == 0:
            if offset < 0 or offset > size:
                raise ValueError('seek() offset outside of valid range.')
            self.offset = offset
        elif frame_of_reference == 1:
            if self.offset + offset < 0 or self.offset + offset > size:
                raise ValueError('seek() offset outside of valid range.')
            self.offset += offset
        elif frame_of_reference == 2:
            if size + offset < 0 or size + offset > size:
                raise ValueError('seek() offset outside of valid range.')
            self.offset = size + offset
        else:
            raise ValueError('seek() frame of reference must be 0, 1 or 2.')

    def tell(self):
        _check_blob_closed(self)
        return self.offset

    def write(self, bytes data):
        # Write `data` at the current offset and advance it. Incremental BLOB
        # I/O cannot grow the value, so the data must fit in the
        # remaining space.
        cdef:
            char *buf
            int size
            Py_ssize_t buflen

        _check_blob_closed(self)
        size = sqlite3_blob_bytes(self.pBlob)
        PyBytes_AsStringAndSize(data, &buf, &buflen)
        # Range-check in Py_ssize_t before any <int> conversion: buflen may
        # exceed INT_MAX, and truncating it would let an oversized write
        # through and hand SQLite a truncated length.
        if buflen > size - self.offset:
            raise ValueError('Data would go beyond end of blob')
        if sqlite3_blob_write(self.pBlob, buf, <int>buflen, self.offset):
            raise OperationalError('Error writing to blob.')
        self.offset += <int>buflen

    def close(self):
        self._close()

    def reopen(self, rowid):
        # Point the open handle at another row of the same table/column and
        # rewind to offset 0. On failure the handle is closed.
        _check_blob_closed(self)
        self.offset = 0
        if sqlite3_blob_reopen(self.pBlob, <long long>rowid):
            self._close()
            raise OperationalError('Unable to re-open blob.')
1397
1398
def sqlite_get_status(flag):
    """
    Return ``(current, highwater)`` for the global SQLite status counter
    ``flag`` (an SQLITE_STATUS_* value), without resetting the high-water
    mark.

    Raises OperationalError if SQLite reports an error. OperationalError
    subclasses Exception, so existing ``except Exception`` handlers still
    apply.
    """
    cdef:
        int current, highwater, rc

    rc = sqlite3_status(flag, &current, &highwater, 0)
    if rc != SQLITE_OK:
        raise OperationalError('Error requesting status: %s' % rc)
    return (current, highwater)
1407
1408
def sqlite_get_db_status(conn, flag):
    """
    Return ``(current, highwater)`` for the per-connection status counter
    ``flag`` (an SQLITE_DBSTATUS_* value) of the sqlite3 connection
    ``conn``, without resetting the high-water mark.

    Raises InterfaceError if ``conn`` is closed, and OperationalError (an
    Exception subclass) if SQLite reports an error.
    """
    cdef:
        int current, highwater, rc
        pysqlite_Connection *c_conn = <pysqlite_Connection *>conn

    # Never hand a NULL sqlite3 handle to sqlite3_db_status().
    _check_connection(c_conn)

    rc = sqlite3_db_status(c_conn.db, flag, &current, &highwater, 0)
    if rc != SQLITE_OK:
        raise OperationalError('Error requesting db status: %s' % rc)
    return (current, highwater)
1418
1419
cdef class ConnectionHelper(object):
    """
    Low-level helpers bound to a sqlite3 connection: commit/rollback/update
    hooks, a jittered busy handler, and a few sqlite3_* accessors.

    NOTE(review): only a raw pointer to ``connection`` is stored, so the
    caller must keep the connection object alive while this helper exists.
    """
    cdef:
        object _commit_hook, _rollback_hook, _update_hook
        pysqlite_Connection *conn

    def __init__(self, connection):
        self.conn = <pysqlite_Connection *>connection
        self._commit_hook = self._rollback_hook = self._update_hook = None

    def __dealloc__(self):
        # When deallocating a Database object, we need to ensure that we clear
        # any commit, rollback or update hooks that may have been applied.
        if not self.conn.initialized or not self.conn.db:
            return

        if self._commit_hook is not None:
            sqlite3_commit_hook(self.conn.db, NULL, NULL)
        if self._rollback_hook is not None:
            sqlite3_rollback_hook(self.conn.db, NULL, NULL)
        if self._update_hook is not None:
            sqlite3_update_hook(self.conn.db, NULL, NULL)

    # Each setter stores the callable on self so it stays referenced while
    # SQLite holds the raw pointer to it; passing None removes the hook.

    def set_commit_hook(self, fn):
        self._commit_hook = fn
        if fn is None:
            sqlite3_commit_hook(self.conn.db, NULL, NULL)
        else:
            sqlite3_commit_hook(self.conn.db, _commit_callback, <void *>fn)

    def set_rollback_hook(self, fn):
        self._rollback_hook = fn
        if fn is None:
            sqlite3_rollback_hook(self.conn.db, NULL, NULL)
        else:
            sqlite3_rollback_hook(self.conn.db, _rollback_callback, <void *>fn)

    def set_update_hook(self, fn):
        self._update_hook = fn
        if fn is None:
            sqlite3_update_hook(self.conn.db, NULL, NULL)
        else:
            sqlite3_update_hook(self.conn.db, _update_callback, <void *>fn)

    def set_busy_handler(self, timeout=5):
        """
        Replace the default busy handler with one that introduces some "jitter"
        into the amount of time delayed between checks.

        ``timeout`` is multiplied by 1000 and used as the millisecond budget
        of the handler, i.e. it is given in seconds.
        """
        # The timeout is passed to the C handler encoded in the pointer value.
        cdef sqlite3_int64 n = timeout * 1000
        sqlite3_busy_handler(self.conn.db, _aggressive_busy_handler, <void *>n)
        return True

    def changes(self):
        return sqlite3_changes(self.conn.db)

    def last_insert_rowid(self):
        # Return the full 64-bit rowid; an <int> cast would truncate rowids
        # outside the 32-bit range.
        return sqlite3_last_insert_rowid(self.conn.db)

    def autocommit(self):
        return sqlite3_get_autocommit(self.conn.db) != 0
1480
1481
cdef int _commit_callback(void *userData) with gil:
    # Bridge from SQLite's commit hook to the Python callable stored in
    # `userData`. A ValueError from the callable vetoes the commit: SQLite
    # turns the COMMIT into a ROLLBACK when the hook returns non-zero. If the
    # callable returns normally (whatever it returns), the commit proceeds.
    # Other exceptions are not caught here.
    cdef object handler = <object>userData
    try:
        handler()
    except ValueError:
        return 1
    return SQLITE_OK
1494
1495
cdef void _rollback_callback(void *userData) with gil:
    # Bridge from SQLite's rollback hook: call the Python callable stored in
    # `userData` with no arguments.
    (<object>userData)()
1500
1501
cdef void _update_callback(void *userData, int queryType, const char *database,
                           const char *table, sqlite3_int64 rowid) with gil:
    # C-callback that delegates to a Python function that is executed whenever
    # the database is updated (insert/update/delete queries). The Python
    # callback receives a string indicating the query type ('INSERT',
    # 'UPDATE', 'DELETE', or '' for anything else), the decoded name of the
    # database, the decoded name of the table being updated, and the rowid of
    # the row being updated.
    cdef object fn = <object>userData
    if queryType == SQLITE_INSERT:
        query = 'INSERT'
    elif queryType == SQLITE_UPDATE:
        query = 'UPDATE'
    elif queryType == SQLITE_DELETE:
        query = 'DELETE'
    else:
        query = ''
    # Pass the full 64-bit rowid; an <int> cast would truncate rowids outside
    # the 32-bit range.
    fn(query, decode(database), decode(table), rowid)
1519
1520
def backup(src_conn, dest_conn, pages=None, name=None, progress=None):
    """
    Copy database ``name`` (default "main") of ``src_conn`` into the "main"
    database of ``dest_conn`` with the SQLite online-backup API.

    :param pages: pages copied per step; ``None``/0 passes -1, which copies
        all remaining pages in a single step.
    :param progress: optional callable invoked after every step as
        ``progress(remaining, page_count, is_done)``.
    :raises InterfaceError: if either connection is closed.
    :raises OperationalError: if the backup cannot start or fails.
    :returns: True on success.
    """
    cdef:
        bytes bname = encode(name or 'main')
        int page_step = pages or -1
        int rc
        pysqlite_Connection *src = <pysqlite_Connection *>src_conn
        pysqlite_Connection *dest = <pysqlite_Connection *>dest_conn
        sqlite3 *src_db
        sqlite3 *dest_db
        sqlite3_backup *backup

    # sqlite3_backup_init() must not be handed a NULL sqlite3 handle.
    _check_connection(src)
    _check_connection(dest)
    src_db = src.db
    dest_db = dest.db

    # We always backup to the "main" database in the dest db.
    backup = sqlite3_backup_init(dest_db, b'main', src_db, bname)
    if backup == NULL:
        raise OperationalError('Unable to initialize backup.')

    while True:
        with nogil:
            rc = sqlite3_backup_step(backup, page_step)
        if progress is not None:
            # Progress-handler is called with (remaining, page count, is done?)
            remaining = sqlite3_backup_remaining(backup)
            page_count = sqlite3_backup_pagecount(backup)
            try:
                progress(remaining, page_count, rc == SQLITE_DONE)
            except:
                # Release the backup handle, then propagate unchanged.
                sqlite3_backup_finish(backup)
                raise
        if rc == SQLITE_BUSY or rc == SQLITE_LOCKED:
            with nogil:
                sqlite3_sleep(250)
        elif rc != SQLITE_OK:
            # SQLITE_DONE on success. Any other code is a fatal error that
            # retrying cannot fix (looping on it would never terminate);
            # stop, and let sqlite3_backup_finish() record it on dest_db.
            break

    with nogil:
        sqlite3_backup_finish(backup)
    if sqlite3_errcode(dest_db):
        raise OperationalError('Error backing up database: %s' %
                               decode(sqlite3_errmsg(dest_db)))
    return True
1561
1562
def backup_to_file(src_conn, filename, pages=None, name=None, progress=None):
    """
    Back up ``src_conn`` into the SQLite database file ``filename`` via
    backup(). The destination connection is always closed, even if the
    backup raises. Returns True on success.
    """
    dest_conn = pysqlite.connect(filename)
    try:
        backup(src_conn, dest_conn, pages=pages, name=name, progress=progress)
    finally:
        dest_conn.close()
    return True
1568
1569
cdef int _aggressive_busy_handler(void *ptr, int n) nogil:
    # In concurrent environments, it often seems that if multiple queries are
    # kicked off at around the same time, they proceed in lock-step to check
    # for the availability of the lock. By introducing some "jitter" we can
    # ensure that this doesn't happen. Furthermore, this function makes more
    # attempts in the same time period than the default handler.
    #
    # ptr: the busy timeout in milliseconds, stored directly in the pointer
    #      value (see ConnectionHelper.set_busy_handler).
    # n:   per the sqlite3_busy_handler contract, the number of prior
    #      invocations for the current lock.
    # Returns 1 after sleeping (SQLite retries) or 0 once the budget is spent
    # (SQLite gives up with SQLITE_BUSY).
    cdef:
        sqlite3_int64 busyTimeout = <sqlite3_int64>ptr
        int current, total

    # `current` is this attempt's randomized sleep; `total` approximates the
    # time already slept using each tier's nominal delay.
    if n < 20:
        current = 25 - (rand() % 10)  # 16-25ms (~20ms)
        total = n * 20
    elif n < 40:
        current = 50 - (rand() % 20)  # 31-50ms (~40ms)
        total = 400 + ((n - 20) * 40)
    else:
        current = 120 - (rand() % 40)  # 81-120ms (~100ms)
        total = 1200 + ((n - 40) * 100)  # Estimate the amount of time slept.

    # Trim the final sleep so the estimated total stays within the timeout.
    if total + current > busyTimeout:
        current = busyTimeout - total
    if current > 0:
        sqlite3_sleep(current)
        return 1
    return 0
1596