1from __future__ import absolute_import
2from pony.py23compat import PY2, imap, basestring, buffer, int_types, unicode
3
4import os.path, sys, re, json
5import sqlite3 as sqlite
6from decimal import Decimal
7from datetime import datetime, date, time, timedelta
8from random import random
9from time import strptime
10from threading import Lock
11from uuid import UUID
12from binascii import hexlify
13from functools import wraps
14
15from pony.orm import core, dbschema, dbapiprovider
16from pony.orm.core import log_orm
17from pony.orm.ormtypes import Json, TrackedArray
18from pony.orm.sqltranslation import SQLTranslator, StringExprMonad
19from pony.orm.sqlbuilding import SQLBuilder, Value, join, make_unary_func
20from pony.orm.dbapiprovider import DBAPIProvider, Pool, wrap_dbapi_exceptions
21from pony.utils import datetime2timestamp, timestamp2datetime, absolutize_path, localbase, throw, reraise, \
22    cut_traceback_depth
23
class SqliteExtensionUnavailable(Exception):
    """Exception type for an unavailable SQLite extension.

    Not raised by the code in this module itself; presumably raised or caught
    elsewhere in Pony — confirm against callers.
    """
26
NoneType = None.__class__  # the type of None, used as a converter key below
28
class SQLiteForeignKey(dbschema.ForeignKey):
    """Foreign-key schema element for SQLite.

    SQLiteSchema sets named_foreign_keys = False, so no standalone
    foreign-key creation command is expected to be generated.
    """
    def get_create_command(foreign_key):
        # Unreachable by design (see the class docstring); fails loudly if reached.
        assert False  # pragma: no cover
32
class SQLiteSchema(dbschema.DBSchema):
    """Schema description settings for the SQLite dialect."""
    # Foreign keys are not given explicit names in generated DDL.
    named_foreign_keys = False
    fk_class = SQLiteForeignKey
    dialect = 'SQLite'
37
def make_overriden_string_func(sqlop):
    """Build a translator method that wraps a string monad in builder op *sqlop*.

    The returned function takes ``(translator, monad)`` and returns a
    StringExprMonad of the same type whose SQL is ``[sqlop, <monad sql>]``
    (e.g. ``['PY_UPPER', expr]``). Its ``__name__`` is set to *sqlop*.
    """
    def func(translator, monad):
        sql = monad.getsql()
        # The wrapped expression must render as exactly one SQL AST node.
        assert len(sql) == 1
        return StringExprMonad(monad.type, [ sqlop, sql[0] ])
    func.__name__ = sqlop
    return func
46
47
class SQLiteTranslator(SQLTranslator):
    """Query translator configured for the SQLite dialect."""
    dialect = 'SQLite'

    # Capability settings (presumably consulted by the generic SQLTranslator).
    rowid_support = True
    row_value_syntax = False
    sqlite_version = sqlite.sqlite_version_info

    # upper()/lower() on strings are translated to the PY_LOWER/PY_UPPER
    # builder operations, which call the Python py_lower/py_upper functions.
    StringMixin_LOWER = make_overriden_string_func('PY_LOWER')
    StringMixin_UPPER = make_overriden_string_func('PY_UPPER')
56
class SQLiteValue(Value):
    """SQL literal renderer with SQLite encodings for date/time values.

    datetime -> quoted timestamp text, date -> quoted 'YYYY-MM-DD' text,
    timedelta -> number of days (float); everything else uses Value.
    """
    __slots__ = []

    def __unicode__(self):
        val = self.value
        # datetime must be tested before date because it subclasses date.
        if isinstance(val, datetime):
            rendered = self.quote_str(datetime2timestamp(val))
        elif isinstance(val, date):
            rendered = self.quote_str(str(val))
        elif isinstance(val, timedelta):
            rendered = repr(val.total_seconds() / (24 * 60 * 60))
        else:
            rendered = Value.__unicode__(self)
        return rendered

    if not PY2: __str__ = __unicode__
69
class SQLiteBuilder(SQLBuilder):
    """SQL builder for the SQLite dialect.

    Each upper-case method renders one SQL AST node and returns a string or a
    tuple of SQL fragments (results of nested ``builder(...)`` calls included
    as-is). Several operations call Python functions that SQLitePool._connect()
    registers on each connection (``py_*`` and ``rand``).
    """
    dialect = 'SQLite'
    least_func_name = 'min'
    greatest_func_name = 'max'
    value_class = SQLiteValue
    def __init__(builder, provider, ast):
        # Selects native json1 functions or their py_json_* fallbacks below.
        builder.json1_available = provider.json1_available
        SQLBuilder.__init__(builder, provider, ast)
    def SELECT_FOR_UPDATE(builder, nowait, skip_locked, *sections):
        # Rendered as a plain SELECT; nowait and skip_locked are ignored.
        assert not builder.indent
        return builder.SELECT(*sections)
    def INSERT(builder, table_name, columns, values, returning=None):
        # A row without explicit values uses the DEFAULT VALUES form.
        if not values: return 'INSERT INTO %s DEFAULT VALUES' % builder.quote_name(table_name)
        return SQLBuilder.INSERT(builder, table_name, columns, values, returning)
    def STRING_SLICE(builder, expr, start, stop):
        # Missing bounds are passed as NULL; py_string_slice treats them as
        # open-ended, like Python's s[start:stop].
        if start is None:
            start = [ 'VALUE', None ]
        if stop is None:
            stop = [ 'VALUE', None ]
        return "py_string_slice(", builder(expr), ', ', builder(start), ', ', builder(stop), ")"
    def IN(builder, expr1, x):
        # An empty list can never match; emit a constant false predicate.
        if not x:
            return '0 = 1'
        if x[0] == 'SELECT':
            return builder(expr1), ' IN ', builder(x)
        # Row values are compared against a VALUES list.
        op = ' IN (VALUES ' if expr1[0] == 'ROW' else ' IN ('
        expr_list = [ builder(expr) for expr in x ]
        return builder(expr1), op, join(', ', expr_list), ')'
    def NOT_IN(builder, expr1, x):
        # Nothing is excluded by an empty list; emit a constant true predicate.
        if not x:
            return '1 = 1'
        if x[0] == 'SELECT':
            return builder(expr1), ' NOT IN ', builder(x)
        op = ' NOT IN (VALUES ' if expr1[0] == 'ROW' else ' NOT IN ('
        expr_list = [ builder(expr) for expr in x ]
        return builder(expr1), op, join(', ', expr_list), ')'
    def TODAY(builder):
        return "date('now', 'localtime')"
    def NOW(builder):
        return "datetime('now', 'localtime')"
    # Date/time parts are extracted by character position, assuming the text
    # layout 'YYYY-MM-DD HH:MM:SS'.
    def YEAR(builder, expr):
        return 'cast(substr(', builder(expr), ', 1, 4) as integer)'
    def MONTH(builder, expr):
        return 'cast(substr(', builder(expr), ', 6, 2) as integer)'
    def DAY(builder, expr):
        return 'cast(substr(', builder(expr), ', 9, 2) as integer)'
    def HOUR(builder, expr):
        return 'cast(substr(', builder(expr), ', 12, 2) as integer)'
    def MINUTE(builder, expr):
        return 'cast(substr(', builder(expr), ', 15, 2) as integer)'
    def SECOND(builder, expr):
        return 'cast(substr(', builder(expr), ', 18, 2) as integer)'
    def datetime_add(builder, funcname, expr, td):
        """Render ``funcname(expr, '<sign>N days', '<sign>N hours', ...)``.

        Converts a constant timedelta *td* into SQLite date/time modifiers.
        Microseconds of *td* are not represented. Returns builder(expr)
        unchanged when *td* has no whole seconds.
        """
        assert isinstance(td, timedelta)
        modifiers = []
        seconds = td.seconds + td.days * 24 * 3600
        sign = '+' if seconds > 0 else '-'
        seconds = abs(seconds)
        if seconds >= (24 * 3600):
            days = seconds // (24 * 3600)
            modifiers.append(", '%s%d days'" % (sign, days))
            seconds -= days * 24 * 3600
        if seconds >= 3600:
            hours = seconds // 3600
            modifiers.append(", '%s%d hours'" % (sign, hours))
            seconds -= hours * 3600
        if seconds >= 60:
            minutes = seconds // 60
            modifiers.append(", '%s%d minutes'" % (sign, minutes))
            seconds -= minutes * 60
        if seconds:
            modifiers.append(", '%s%d seconds'" % (sign, seconds))
        if not modifiers: return builder(expr)
        return funcname, '(', builder(expr), modifiers, ')'
    def DATE_ADD(builder, expr, delta):
        if delta[0] == 'VALUE' and isinstance(delta[1], timedelta):
            return builder.datetime_add('date', expr, delta[1])
        # Non-constant deltas are stored as a number of days (see
        # SQLiteTimedeltaConverter). Wrap the result in date() so the outcome
        # is a 'YYYY-MM-DD' date, consistent with the constant branch above.
        return 'date(julianday(', builder(expr), ') + ', builder(delta), ')'
    def DATE_SUB(builder, expr, delta):
        if delta[0] == 'VALUE' and isinstance(delta[1], timedelta):
            return builder.datetime_add('date', expr, -delta[1])
        # See DATE_ADD: the result must be a date, not a datetime string.
        return 'date(julianday(', builder(expr), ') - ', builder(delta), ')'
    def DATE_DIFF(builder, expr1, expr2):
        # Difference expressed in (fractional) days.
        return 'julianday(', builder(expr1), ') - julianday(', builder(expr2), ')'
    def DATETIME_ADD(builder, expr, delta):
        if delta[0] == 'VALUE' and isinstance(delta[1], timedelta):
            return builder.datetime_add('datetime', expr, delta[1])
        return 'datetime(julianday(', builder(expr), ') + ', builder(delta), ')'
    def DATETIME_SUB(builder, expr, delta):
        if delta[0] == 'VALUE' and isinstance(delta[1], timedelta):
            return builder.datetime_add('datetime', expr, -delta[1])
        return 'datetime(julianday(', builder(expr), ') - ', builder(delta), ')'
    def DATETIME_DIFF(builder, expr1, expr2):
        return 'julianday(', builder(expr1), ') - julianday(', builder(expr2), ')'
    def RANDOM(builder):
        # rand() is the Python random.random registered on the connection.
        return 'rand()'  # return '(random() / 9223372036854775807.0 + 1.0) / 2.0'
    PY_UPPER = make_unary_func('py_upper')
    PY_LOWER = make_unary_func('py_lower')
    # Floats compare with a relative tolerance of 1e-14:
    # |a - b| / max(|a|, |b|) (divisor 1 when both are 0).
    def FLOAT_EQ(builder, a, b):
        a, b = builder(a), builder(b)
        return 'abs(', a, ' - ', b, ') / coalesce(nullif(max(abs(', a, '), abs(', b, ')), 0), 1) <= 1e-14'
    def FLOAT_NE(builder, a, b):
        a, b = builder(a), builder(b)
        return 'abs(', a, ' - ', b, ') / coalesce(nullif(max(abs(', a, '), abs(', b, ')), 0), 1) > 1e-14'
    def JSON_QUERY(builder, expr, path):
        fname = 'json_extract' if builder.json1_available else 'py_json_extract'
        path_sql, has_params, has_wildcards = builder.build_json_path(path)
        # The extra null path makes the extraction return '[null,<value>]',
        # which py_json_unwrap strips down to the value's JSON text.
        return 'py_json_unwrap(', fname, '(', builder(expr), ', null, ', path_sql, '))'
    # Python type requested for JSON_VALUE -> SQL type used in CAST.
    json_value_type_mapping = {unicode: 'text', bool: 'integer', int: 'integer', float: 'real'}
    def JSON_VALUE(builder, expr, path, type):
        func_name = 'json_extract' if builder.json1_available else 'py_json_extract'
        path_sql, has_params, has_wildcards = builder.build_json_path(path)
        type_name = builder.json_value_type_mapping.get(type)
        result = func_name, '(', builder(expr), ', ', path_sql, ')'
        if type_name is not None: result = 'CAST(', result, ' as ', type_name, ')'
        return result
    def JSON_NONZERO(builder, expr):
        # JSON texts considered "falsy".
        return builder(expr), ''' NOT IN ('null', 'false', '0', '""', '[]', '{}')'''
    def JSON_ARRAY_LENGTH(builder, value):
        func_name = 'json_array_length' if builder.json1_available else 'py_json_array_length'
        return func_name, '(', builder(value), ')'
    def JSON_CONTAINS(builder, expr, path, key):
        path_sql, has_params, has_wildcards = builder.build_json_path(path)
        return 'py_json_contains(', builder(expr), ', ', path_sql, ',  ', builder(key), ')'
    # Array operations are delegated to py_array_* functions working on the
    # JSON text in which arrays are stored.
    def ARRAY_INDEX(builder, col, index):
        return 'py_array_index(', builder(col), ', ', builder(index), ')'
    def ARRAY_CONTAINS(builder, key, not_in, col):
        return ('NOT ' if not_in else ''), 'py_array_contains(', builder(col), ', ', builder(key), ')'
    def ARRAY_SUBSET(builder, array1, not_in, array2):
        return ('NOT ' if not_in else ''), 'py_array_subset(', builder(array2), ', ', builder(array1), ')'
    def ARRAY_LENGTH(builder, array):
        return 'py_array_length(', builder(array), ')'
    def ARRAY_SLICE(builder, array, start, stop):
        return 'py_array_slice(', builder(array), ', ', \
               builder(start) if start else 'null', ',',\
               builder(stop) if stop else 'null', ')'
    def MAKE_ARRAY(builder, *items):
        return 'py_make_array(', join(', ', (builder(item) for item in items)), ')'
208
class SQLiteIntConverter(dbapiprovider.IntConverter):
    """Integer converter that forces the INTEGER type for auto columns."""
    def sql_type(converter):
        attr = converter.attr
        is_auto = attr is not None and attr.auto
        if not is_auto:
            return dbapiprovider.IntConverter.sql_type(converter)
        # Only this type can have AUTOINCREMENT option
        return 'INTEGER'
214
class SQLiteDecimalConverter(dbapiprovider.DecimalConverter):
    """Decimal converter that stores values as their ``str()`` text.

    When the attribute has a fixed scale (``converter.exp``), values are
    quantized to it on both read and write.
    """
    # Kept as public attributes for backward compatibility.
    inf = Decimal('infinity')
    neg_inf = Decimal('-infinity')
    NaN = Decimal('NaN')
    def sql2py(converter, val):
        # Best effort: a value that cannot be parsed as Decimal is returned as-is.
        try: val = Decimal(str(val))
        except Exception: return val
        exp = converter.exp
        if exp is not None: val = val.quantize(exp)
        return val
    def py2sql(converter, val):
        """Return the text form of *val*.

        Raises ValueError for infinite or NaN values when a fixed scale is set.
        """
        if type(val) is not Decimal: val = Decimal(val)
        exp = converter.exp
        if exp is not None:
            # is_finite() catches inf, -inf, NaN and sNaN. A membership test
            # against (inf, -inf, NaN) misses NaN, because NaN != NaN.
            if not val.is_finite():
                throw(ValueError, 'Cannot store %s Decimal value in database' % val)
            val = val.quantize(exp)
        return str(val)
233
class SQLiteDateConverter(dbapiprovider.DateConverter):
    """Date converter storing dates as 'YYYY-MM-DD' text."""
    def sql2py(converter, val):
        # Only the first 10 characters are parsed, so a value carrying a time
        # part is accepted too. Anything unparsable (wrong type or format) is
        # returned unchanged (best effort). Exception rather than a bare except
        # lets KeyboardInterrupt/SystemExit propagate.
        try:
            time_tuple = strptime(val[:10], '%Y-%m-%d')
            return date(*time_tuple[:3])
        except Exception: return val
    def py2sql(converter, val):
        return val.strftime('%Y-%m-%d')
242
class SQLiteTimeConverter(dbapiprovider.TimeConverter):
    """Time converter storing times as ISO 'HH:MM:SS[.ffffff]' text."""
    def sql2py(converter, val):
        # Values of up to 8 characters have no fractional seconds. Unparsable
        # values are returned unchanged (best effort); Exception rather than a
        # bare except so that KeyboardInterrupt/SystemExit still propagate.
        try:
            if len(val) <= 8: dt = datetime.strptime(val, '%H:%M:%S')
            else: dt = datetime.strptime(val, '%H:%M:%S.%f')
            return dt.time()
        except Exception: return val
    def py2sql(converter, val):
        return val.isoformat()
252
class SQLiteTimedeltaConverter(dbapiprovider.TimedeltaConverter):
    """Timedelta converter storing durations as a float number of days."""
    def py2sql(converter, val):
        day_fraction = (val.seconds + val.microseconds / 1000000.0) / 86400.0
        return val.days + day_fraction
    def sql2py(converter, val):
        return timedelta(days=val)
258
class SQLiteDatetimeConverter(dbapiprovider.DatetimeConverter):
    """Datetime converter using pony.utils timestamp text helpers."""
    def sql2py(converter, val):
        # Best effort: values that timestamp2datetime cannot parse are returned
        # unchanged. Exception rather than a bare except, so that
        # KeyboardInterrupt/SystemExit still propagate.
        try: return timestamp2datetime(val)
        except Exception: return val
    def py2sql(converter, val):
        return datetime2timestamp(val)
265
class SQLiteJsonConverter(dbapiprovider.JsonConverter):
    """JSON converter producing compact, key-sorted, non-ASCII-escaped text."""
    json_kwargs = {
        'separators': (',', ':'),
        'sort_keys': True,
        'ensure_ascii': False,
    }
268
def dumps(items):
    """Serialize *items* to JSON text with SQLiteJsonConverter's settings."""
    options = SQLiteJsonConverter.json_kwargs
    text = json.dumps(items, **options)
    return text
271
class SQLiteArrayConverter(dbapiprovider.ArrayConverter):
    """Array converter storing arrays as JSON text."""
    # Item type -> (SQL type name, item converter class).
    array_types = {
        int: ('int', SQLiteIntConverter),
        unicode: ('text', dbapiprovider.StrConverter),
        float: ('real', dbapiprovider.RealConverter)
    }

    def dbval2val(converter, dbval, obj=None):
        # Empty/NULL database values map to None. With an owning object, the
        # list is wrapped in a TrackedArray bound to that object and attribute.
        if dbval:
            items = json.loads(dbval)
            return items if obj is None else TrackedArray(obj, converter.attr, items)
        return None

    def val2dbval(converter, val, obj=None):
        return dumps(val)
288
class LocalExceptions(localbase):
    """Holder for the last exception raised inside a registered SQL function.

    Subclasses pony.utils.localbase (presumably thread-local storage — confirm).
    """
    def __init__(self):
        # exc_info: tuple saved by keep_exception(), or None when none pending.
        # keep_traceback: whether keep_exception() keeps the traceback object.
        self.exc_info, self.keep_traceback = None, False

local_exceptions = LocalExceptions()
295
def keep_exception(func):
    """Wrap a SQLite user function so its Python exception is remembered.

    Any Exception raised by *func* is stored in ``local_exceptions.exc_info``
    and then re-raised. The traceback is dropped unless
    ``local_exceptions.keep_traceback`` is set, and that flag is reset after
    every call.
    """
    @wraps(func)
    def new_func(*args):
        state = local_exceptions
        state.exc_info = None
        try:
            return func(*args)
        except Exception:
            info = sys.exc_info()
            if not state.keep_traceback:
                info = info[:2] + (None,)
            state.exc_info = info
            raise
        finally:
            state.keep_traceback = False
    return new_func
310
311
class SQLiteProvider(DBAPIProvider):
    """DBAPIProvider for SQLite, built on the standard-library sqlite3 module.

    Transactions started in immediate mode (see set_transaction_mode) hold
    ``transaction_lock`` until commit(), rollback() or drop() releases it. As a
    result, at most one such transaction per provider is active at a time.
    """
    dialect = 'SQLite'
    # Same object the keep_exception() wrapper writes caught exceptions into.
    local_exceptions = local_exceptions
    max_name_len = 1024

    dbapi_module = sqlite
    dbschema_cls = SQLiteSchema
    translator_cls = SQLiteTranslator
    sqlbuilder_cls = SQLiteBuilder
    array_converter_cls = SQLiteArrayConverter

    name_before_table = 'db_name'

    server_version = sqlite.sqlite_version_info

    # (Python type, converter class) pairs. datetime is listed before date,
    # presumably because entries are matched in order and datetime subclasses date.
    converter_classes = [
        (NoneType, dbapiprovider.NoneConverter),
        (bool, dbapiprovider.BoolConverter),
        (basestring, dbapiprovider.StrConverter),
        (int_types, SQLiteIntConverter),
        (float, dbapiprovider.RealConverter),
        (Decimal, SQLiteDecimalConverter),
        (datetime, SQLiteDatetimeConverter),
        (date, SQLiteDateConverter),
        (time, SQLiteTimeConverter),
        (timedelta, SQLiteTimedeltaConverter),
        (UUID, dbapiprovider.UuidConverter),
        (buffer, dbapiprovider.BlobConverter),
        (Json, SQLiteJsonConverter)
    ]

    def __init__(provider, *args, **kwargs):
        DBAPIProvider.__init__(provider, *args, **kwargs)
        # Locks used by acquire_lock()/release_lock() for immediate transactions.
        provider.pre_transaction_lock = Lock()
        provider.transaction_lock = Lock()

    @wrap_dbapi_exceptions
    def inspect_connection(provider, conn):
        """Run base inspection, then record whether json1 functions are available."""
        DBAPIProvider.inspect_connection(provider, conn)
        provider.json1_available = provider.check_json1(conn)

    def restore_exception(provider):
        """Re-raise the exception saved by a keep_exception()-wrapped SQL function.

        Does nothing if none is pending. The saved exception is cleared either way.
        """
        if provider.local_exceptions.exc_info is not None:
            try: reraise(*provider.local_exceptions.exc_info)
            finally: provider.local_exceptions.exc_info = None

    def acquire_lock(provider):
        """Acquire transaction_lock while holding pre_transaction_lock.

        pre_transaction_lock is released as soon as transaction_lock has been
        obtained, or if acquiring it fails.
        """
        # NOTE(review): the purpose of the two-stage handoff is not stated
        # here; presumably it controls how waiting threads queue — confirm.
        provider.pre_transaction_lock.acquire()
        try:
            provider.transaction_lock.acquire()
        finally:
            provider.pre_transaction_lock.release()

    def release_lock(provider):
        provider.transaction_lock.release()

    @wrap_dbapi_exceptions
    def set_transaction_mode(provider, connection, cache):
        """Prepare *connection* for the session described by *cache*.

        Immediate sessions take the provider lock and run BEGIN IMMEDIATE
        TRANSACTION. Other sessions stay in autocommit mode (the pool opens
        connections with isolation_level=None). DDL sessions must be immediate;
        they also switch foreign keys off and remember the previous state in
        cache.saved_fk_state so release() can restore it.
        """
        assert not cache.in_transaction
        if cache.immediate:
            provider.acquire_lock()
        try:
            cursor = connection.cursor()

            db_session = cache.db_session
            if db_session is not None and db_session.ddl:
                # Done before BEGIN; presumably because SQLite documents
                # PRAGMA foreign_keys as a no-op inside a transaction.
                cursor.execute('PRAGMA foreign_keys')
                fk = cursor.fetchone()
                if fk is not None: fk = fk[0]
                if fk:
                    sql = 'PRAGMA foreign_keys = false'
                    if core.local.debug: log_orm(sql)
                    cursor.execute(sql)
                cache.saved_fk_state = bool(fk)
                assert cache.immediate

            if cache.immediate:
                sql = 'BEGIN IMMEDIATE TRANSACTION'
                if core.local.debug: log_orm(sql)
                cursor.execute(sql)
                cache.in_transaction = True
            elif core.local.debug: log_orm('SWITCH TO AUTOCOMMIT MODE')
        finally:
            # If the lock was taken but no transaction started (e.g. BEGIN
            # failed), give the lock back now; nobody else would.
            if cache.immediate and not cache.in_transaction:
                provider.release_lock()

    def commit(provider, connection, cache=None):
        """Commit, then release the transaction lock if a transaction was open."""
        in_transaction = cache is not None and cache.in_transaction
        try:
            DBAPIProvider.commit(provider, connection, cache)
        finally:
            # The lock is released even if the commit itself fails.
            if in_transaction:
                cache.in_transaction = False
                provider.release_lock()

    def rollback(provider, connection, cache=None):
        """Roll back, then release the transaction lock if a transaction was open."""
        in_transaction = cache is not None and cache.in_transaction
        try:
            DBAPIProvider.rollback(provider, connection, cache)
        finally:
            if in_transaction:
                cache.in_transaction = False
                provider.release_lock()

    def drop(provider, connection, cache=None):
        """Drop the connection, then release the transaction lock if one was open."""
        in_transaction = cache is not None and cache.in_transaction
        try:
            DBAPIProvider.drop(provider, connection, cache)
        finally:
            if in_transaction:
                cache.in_transaction = False
                provider.release_lock()

    @wrap_dbapi_exceptions
    def release(provider, connection, cache=None):
        """Return *connection* to the pool.

        Foreign keys are re-enabled first if a DDL session turned them off. If
        that fails, the connection is dropped from the pool and the error is
        re-raised.
        """
        if cache is not None:
            db_session = cache.db_session
            if db_session is not None and db_session.ddl and cache.saved_fk_state:
                try:
                    cursor = connection.cursor()
                    sql = 'PRAGMA foreign_keys = true'
                    if core.local.debug: log_orm(sql)
                    cursor.execute(sql)
                except:
                    provider.pool.drop(connection)
                    raise
        DBAPIProvider.release(provider, connection, cache)

    def get_pool(provider, filename, create_db=False, **kwargs):
        """Create the SQLitePool for *filename*.

        Non-':memory:' paths are resolved relative to the user module that
        created the Database (located by stack-frame depth).
        """
        if filename != ':memory:':
            # When relative filename is specified, it is considered
            # not relative to cwd, but to user module where
            # Database instance is created

            # the list of frames:
            # 7 - user code: db = Database(...)
            # 6 - cut_traceback decorator wrapper
            # 5 - cut_traceback decorator
            # 4 - pony.orm.Database.__init__() / .bind()
            # 3 - pony.orm.Database._bind()
            # 2 - pony.dbapiprovider.DBAPIProvider.__init__()
            # 1 - SQLiteProvider.__init__()
            # 0 - pony.dbproviders.sqlite.get_pool()
            filename = absolutize_path(filename, frame_depth=cut_traceback_depth+5)
        return SQLitePool(filename, create_db, **kwargs)

    def table_exists(provider, connection, table_name, case_sensitive=True):
        """Return the stored table name if the table exists, else None."""
        return provider._exists(connection, table_name, None, case_sensitive)

    def index_exists(provider, connection, table_name, index_name, case_sensitive=True):
        """Return the stored index name if the index exists, else None."""
        return provider._exists(connection, table_name, index_name, case_sensitive)

    def _exists(provider, connection, table_name, index_name=None, case_sensitive=True):
        """Look up a table (or, when index_name is given, an index) in sqlite_master.

        table_name may carry a database prefix, which selects that database's
        catalog. For index lookups the bare table name is not used. Returns the
        matching name, or None.
        """
        db_name, table_name = provider.split_table_name(table_name)

        if db_name is None: catalog_name = 'sqlite_master'
        else: catalog_name = (db_name, 'sqlite_master')
        catalog_name = provider.quote_name(catalog_name)

        cursor = connection.cursor()
        if index_name is not None:
            sql = "SELECT name FROM %s WHERE type='index' AND name=?" % catalog_name
            if not case_sensitive: sql += ' COLLATE NOCASE'
            cursor.execute(sql, [ index_name ])
        else:
            sql = "SELECT name FROM %s WHERE type='table' AND name=?" % catalog_name
            if not case_sensitive: sql += ' COLLATE NOCASE'
            cursor.execute(sql, [ table_name ])
        row = cursor.fetchone()
        return row[0] if row is not None else None

    def fk_exists(provider, connection, table_name, fk_name):
        # Not expected to be called; presumably because SQLiteSchema uses
        # unnamed foreign keys (named_foreign_keys = False).
        assert False  # pragma: no cover

    def check_json1(provider, connection):
        """Return True if calling json() succeeds on *connection*.

        Only sqlite3.OperationalError is treated as "json1 unavailable".
        """
        cursor = connection.cursor()
        sql = '''
            select json('{"this": "is", "a": ["test"]}')'''
        try:
            cursor.execute(sql)
            return True
        except sqlite.OperationalError:
            return False
495
# Module-level alias for this module's provider class (presumably the name
# Pony looks up when loading a provider module — confirm).
provider_cls = SQLiteProvider
497
498def _text_factory(s):
499    return s.decode('utf8', 'replace')
500
def make_string_function(name, base_func):
    """Create a NULL-safe SQL string function named *name* around *base_func*.

    NULL stays NULL. Blob (buffer) values are converted to their hex text, and
    other non-text values are converted with unicode() before *base_func* is
    applied.
    """
    def func(value):
        if value is None:
            return None
        kind = type(value)
        if kind is buffer:
            value = hexlify(value).decode('ascii')
        elif kind is not unicode:
            value = unicode(value)
        return base_func(value)
    func.__name__ = name
    return func
515
# Python case conversion exposed to SQL as py_upper()/py_lower().
py_upper, py_lower = (
    make_string_function('py_upper', unicode.upper),
    make_string_function('py_lower', unicode.lower),
)
518
def py_json_unwrap(value):
    """Turn '[null,<json>]' into '<json>'; NULL stays NULL."""
    if value is not None:
        assert value.startswith('[null,'), value
        # Drop the 6-character '[null,' prefix and the closing ']'.
        return value[6:-1]
    return None
525
# JSON path string -> parsed tuple of keys (None if unparsable). Unbounded.
path_cache = dict()

# One path step: "[N]" array index, ".name" member, or '."any name"' member.
json_path_re = re.compile(r'\[(-?\d+)\]|\.(?:(\w+)|"([^"]*)")', flags=re.UNICODE)
529
def _parse_path(path):
    """Parse a '$...' JSON path into a tuple of keys (ints and strings).

    Returns None for anything that is not a string starting with '$' or that
    contains an unparsable step. Results are memoized in path_cache.
    """
    try:
        return path_cache[path]
    except KeyError:
        pass
    keys = None
    if isinstance(path, basestring) and path.startswith('$'):
        parsed = []
        pos, end = 1, len(path)
        while pos < end:
            step = json_path_re.match(path, pos)
            if step is None:
                parsed = None
                break
            index, name, quoted = step.groups()
            parsed.append(int(index) if index else name or quoted)
            pos = step.end()
        keys = None if parsed is None else tuple(parsed)
    path_cache[path] = keys
    return keys
550
551def _traverse(obj, keys):
552    if keys is None: return None
553    list_or_dict = (list, dict)
554    for key in keys:
555        if type(obj) not in list_or_dict: return None
556        try: obj = obj[key]
557        except (KeyError, IndexError): return None
558    return obj
559
def _extract(expr, *paths):
    """Evaluate each JSON path against *expr* (JSON text or already-parsed data).

    Returns the single value for one path, otherwise the list of values.
    """
    data = json.loads(expr) if isinstance(expr, basestring) else expr
    found = [_traverse(data, _parse_path(p)) for p in paths]
    if len(paths) == 1:
        return found[0]
    return found
567
def py_json_extract(expr, *paths):
    """Python fallback for json_extract(): containers come back as JSON text."""
    value = _extract(expr, *paths)
    if type(value) not in (list, dict):
        return value
    return json.dumps(value, **SQLiteJsonConverter.json_kwargs)
573
def py_json_query(expr, path, with_wrapper):
    """Return the container at *path* as JSON text.

    A scalar result is wrapped in a one-element array if *with_wrapper* is
    set; otherwise the result is None.
    """
    value = _extract(expr, path)
    if type(value) not in (list, dict):
        if with_wrapper:
            value = [value]
        else:
            return None
    return json.dumps(value, **SQLiteJsonConverter.json_kwargs)
580
def py_json_value(expr, path):
    """Return the scalar at *path*, or None if it is a list/dict."""
    value = _extract(expr, path)
    if type(value) in (list, dict):
        return None
    return value
584
def py_json_contains(expr, path, key):
    """True if the list/dict at *path* contains *key* (element or dict key)."""
    if isinstance(expr, basestring):
        expr = json.loads(expr)
    target = _traverse(expr, _parse_path(path))
    if type(target) not in (list, dict):
        return False
    return key in target
590
def py_json_nonzero(expr, path):
    """Python truthiness of the value at *path* (missing values are falsy)."""
    if isinstance(expr, basestring):
        expr = json.loads(expr)
    return bool(_traverse(expr, _parse_path(path)))
596
def py_json_array_length(expr, path=None):
    """Length of the JSON array at *path* (or at the root); 0 if not an array."""
    data = json.loads(expr) if isinstance(expr, basestring) else expr
    if path:
        data = _traverse(data, _parse_path(path))
    if type(data) is not list:
        return 0
    return len(data)
603
def wrap_array_func(func):
    """Decorate *func* so its first argument arrives as JSON text.

    The text is decoded before *func* is called; a NULL (None) array
    short-circuits to None.
    """
    @wraps(func)
    def new_func(array, *args):
        if array is not None:
            return func(json.loads(array), *args)
        return None
    return new_func
612
@wrap_array_func
def py_array_index(array, index):
    """Element at *index* (Python indexing rules), or None when out of range."""
    try:
        item = array[index]
    except IndexError:
        item = None
    return item
619
@wrap_array_func
def py_array_contains(array, item):
    """True if *item* is an element of the decoded array."""
    found = item in array
    return found
623
@wrap_array_func
def py_array_subset(array, items):
    """True if every element of JSON array *items* is in *array*; NULL-safe."""
    if items is not None:
        return set(json.loads(items)) <= set(array)
    return None
629
@wrap_array_func
def py_array_length(array):
    """Number of elements in the decoded array."""
    size = len(array)
    return size
633
@wrap_array_func
def py_array_slice(array, start, stop):
    """JSON text of array[start:stop]; NULL bounds act as open ends."""
    sliced = array[slice(start, stop)]
    return dumps(sliced)
637
def py_make_array(*items):
    """Build a JSON array text from the SQL arguments."""
    return dumps(list(items))
640
def py_string_slice(s, start, end):
    """Python-style s[start:end] for SQL; bounds given as text are converted to int."""
    if s is None:
        return None
    start = int(start) if isinstance(start, basestring) else start
    end = int(end) if isinstance(end, basestring) else end
    return s[start:end]
649
class SQLitePool(Pool):
    """Connection holder for a single SQLite database.

    For the ':memory:' filename, disconnect() keeps the connection open and
    drop() only rolls back. This is presumably because an in-memory database
    exists only as long as its connection.
    """
    def __init__(pool, filename, create_db, **kwargs): # called separately in each thread
        pool.con = None
        pool.filename, pool.create_db, pool.kwargs = filename, create_db, kwargs

    def _connect(pool):
        filename = pool.filename
        if not (filename == ':memory:' or pool.create_db or os.path.exists(filename)):
            throw(IOError, "Database file is not found: %r" % filename)
        # isolation_level=None: the provider issues BEGIN itself.
        con = sqlite.connect(filename, isolation_level=None, **pool.kwargs)
        pool.con = con
        con.text_factory = _text_factory

        # (SQL name, argument count or -1 for variadic, Python implementation)
        user_functions = (
            ('power', 2, pow),
            ('rand', 0, random),
            ('py_upper', 1, py_upper),
            ('py_lower', 1, py_lower),
            ('py_json_unwrap', 1, py_json_unwrap),
            ('py_json_extract', -1, py_json_extract),
            ('py_json_contains', 3, py_json_contains),
            ('py_json_nonzero', 2, py_json_nonzero),
            ('py_json_array_length', -1, py_json_array_length),
            ('py_array_index', 2, py_array_index),
            ('py_array_contains', 2, py_array_contains),
            ('py_array_subset', 2, py_array_subset),
            ('py_array_length', 1, py_array_length),
            ('py_array_slice', 3, py_array_slice),
            ('py_make_array', -1, py_make_array),
            ('py_string_slice', 3, py_string_slice),
        )
        for name, num_params, impl in user_functions:
            # keep_exception() records the Python exception so that
            # SQLiteProvider.restore_exception() can re-raise it.
            con.create_function(name, num_params, keep_exception(impl))

        # Foreign-key enforcement is only switched on for SQLite >= 3.6.19.
        if sqlite.sqlite_version_info >= (3, 6, 19):
            con.execute('PRAGMA foreign_keys = true')

        con.execute('PRAGMA case_sensitive_like = true')

    def disconnect(pool):
        if pool.filename == ':memory:':
            return
        Pool.disconnect(pool)

    def drop(pool, con):
        if pool.filename == ':memory:':
            con.rollback()
        else:
            Pool.drop(pool, con)
698