1# -*- coding: utf-8 -*-
2# This file is part of beets.
3# Copyright 2016, Adrian Sampson.
4#
5# Permission is hereby granted, free of charge, to any person obtaining
6# a copy of this software and associated documentation files (the
7# "Software"), to deal in the Software without restriction, including
8# without limitation the rights to use, copy, modify, merge, publish,
9# distribute, sublicense, and/or sell copies of the Software, and to
10# permit persons to whom the Software is furnished to do so, subject to
11# the following conditions:
12#
13# The above copyright notice and this permission notice shall be
14# included in all copies or substantial portions of the Software.
15
16"""The central Model and Database constructs for DBCore.
17"""
18from __future__ import division, absolute_import, print_function
19
20import time
21import os
22from collections import defaultdict
23import threading
24import sqlite3
25import contextlib
26
27import beets
28from beets.util import functemplate
29from beets.util import py3_path
30from beets.dbcore import types
31from .query import MatchQuery, NullSort, TrueQuery
32import six
33if six.PY2:
34    from collections import Mapping
35else:
36    from collections.abc import Mapping
37
38
class DBAccessError(Exception):
    """Raised when the SQLite database can no longer be accessed.

    Typical causes are the database file being deleted or otherwise
    vanishing while it is being read or written. The condition is most
    likely unrecoverable, so callers are expected to abort.
    """
46
47
class FormattedMapping(Mapping):
    """A read-only, `dict`-like view of a model with formatted values.

    Looking up `mapping[key]` yields `model[key]` rendered by the
    field's type as a unicode string.

    When `for_path` is true, every path separator occurring in a
    formatted value is substituted with the configured replacement.
    """

    def __init__(self, model, for_path=False):
        self.for_path = for_path
        self.model = model
        # Snapshot of the model's keys, computed fields included.
        self.model_keys = model.keys(True)

    def __getitem__(self, key):
        if key not in self.model_keys:
            raise KeyError(key)
        return self._get_formatted(self.model, key)

    def __iter__(self):
        return iter(self.model_keys)

    def __len__(self):
        return len(self.model_keys)

    def get(self, key, default=None):
        """Return the formatted value for `key`. When no `default` is
        given, the fallback is the field type's rendering of `None`.
        """
        fallback = default
        if fallback is None:
            fallback = self.model._type(key).format(None)
        return super(FormattedMapping, self).get(key, fallback)

    def _get_formatted(self, model, key):
        """Format `model`'s value for `key` as a unicode string."""
        formatted = model._type(key).format(model.get(key))
        if isinstance(formatted, bytes):
            formatted = formatted.decode('utf-8', 'ignore')

        if not self.for_path:
            return formatted

        # Neutralize path separators so the value cannot introduce new
        # path components (`altsep` is None on some platforms).
        replacement = beets.config['path_sep_replace'].as_str()
        for separator in filter(None, (os.path.sep, os.path.altsep)):
            formatted = formatted.replace(separator, replacement)
        return formatted
92
93
class LazyConvertDict(object):
    """Lazily convert types for attributes fetched from the database.

    Raw values as read from SQL live in `data`; the first time a key is
    read it is converted through the model's field type and the result
    is cached in `_converted`. Values assigned directly are stored in
    `_converted` only and are assumed to be converted already.
    """

    def __init__(self, model_cls):
        """Initialize the object empty.

        `model_cls` must provide `_type(key)`, returning an object with
        a `from_sql(value)` method.
        """
        self.data = {}
        self.model_cls = model_cls
        self._converted = {}

    def init(self, data):
        """Set the base data that should be lazily converted.

        The mapping is stored by reference, not copied.
        """
        self.data = data

    def _convert(self, key, value):
        """Convert the attribute type according to the SQL type.
        """
        return self.model_cls._type(key).from_sql(value)

    def __setitem__(self, key, value):
        """Set an attribute value, assume it's already converted.
        """
        self._converted[key] = value

    def __getitem__(self, key):
        """Get an attribute value, converting the type on demand
        if needed.

        A key that is absent from both stores yields `None` rather than
        raising `KeyError`; use `in` or `get()` to tell the two apart.
        """
        if key in self._converted:
            return self._converted[key]
        if key in self.data:
            value = self._convert(key, self.data[key])
            # Cache the conversion; the raw value stays in `data`.
            self._converted[key] = value
            return value
        return None

    def __delitem__(self, key):
        """Delete both converted and base data.
        """
        if key in self._converted:
            del self._converted[key]
        if key in self.data:
            del self.data[key]

    def keys(self):
        """Get a list of available field names for this object.

        Each name appears exactly once: converted keys first, followed
        by the raw keys that have not been converted yet.
        """
        converted = self._converted
        # A key that was read once lives in both dicts; list it only once.
        pending = [key for key in self.data if key not in converted]
        return list(converted) + pending

    def copy(self):
        """Create a copy of the object.
        """
        new = self.__class__(self.model_cls)
        new.data = self.data.copy()
        new._converted = self._converted.copy()
        return new

    # Act like a dictionary.

    def update(self, values):
        """Assign all values in the given dict.
        """
        for key, value in values.items():
            self[key] = value

    def items(self):
        """Iterate over (key, value) pairs that this object contains.
        Values are converted on demand as they are produced.
        """
        for key in self:
            yield key, self[key]

    def get(self, key, default=None):
        """Get the value for a given key or `default` if it does not
        exist.
        """
        if key in self:
            return self[key]
        else:
            return default

    def __contains__(self, key):
        """Determine whether `key` is an attribute on this object.
        """
        # Direct dict lookups instead of building the key list each time.
        return key in self._converted or key in self.data

    def __iter__(self):
        """Iterate over the available field names.
        """
        return iter(self.keys())
186
187
188# Abstract base for model classes.
189
class Model(object):
    """An abstract object representing an object in the database. Model
    objects act like dictionaries (i.e., they allow subscript access like
    ``obj['field']``). The same field set is available via attribute
    access as a shortcut (i.e., ``obj.field``). Three kinds of attributes are
    available:

    * **Fixed attributes** come from a predetermined list of field
      names. These fields correspond to SQLite table columns and are
      thus fast to read, write, and query.
    * **Flexible attributes** are free-form and do not need to be listed
      ahead of time.
    * **Computed attributes** are read-only fields computed by a getter
      function provided by a plugin.

    Access to all three field types is uniform: ``obj.field`` works the
    same regardless of whether ``field`` is fixed, flexible, or
    computed.

    Model objects can optionally be associated with a `Library` object,
    in which case they can be loaded and stored from the database. Dirty
    flags are used to track which fields need to be stored.
    """

    # Abstract components (to be provided by subclasses).

    _table = None
    """The main SQLite table name.
    """

    _flex_table = None
    """The flex field SQLite table name.
    """

    _fields = {}
    """A mapping indicating available "fixed" fields on this type. The
    keys are field names and the values are `Type` objects.
    """

    _search_fields = ()
    """The fields that should be queried by default by unqualified query
    terms.
    """

    _types = {}
    """Optional Types for non-fixed (i.e., flexible and computed) fields.
    """

    _sorts = {}
    """Optional named sort criteria. The keys are strings and the values
    are subclasses of `Sort`.
    """

    _queries = {}
    """Named queries that use a field-like `name:value` syntax but which
    do not relate to any specific field.
    """

    _always_dirty = False
    """By default, fields only become "dirty" when their value actually
    changes. Enabling this flag marks fields as dirty even when the new
    value is the same as the old value (e.g., `o.f = o.f`).
    """

    @classmethod
    def _getters(cls):
        """Return a mapping from field names to getter functions.

        Must be implemented by subclasses.
        """
        # We could cache this if it becomes a performance problem to
        # gather the getter mapping every time.
        raise NotImplementedError()

    def _template_funcs(self):
        """Return a mapping from function names to text-transformer
        functions.

        Must be implemented by subclasses.
        """
        # As above: we could consider caching this result.
        raise NotImplementedError()

    # Basic operation.

    def __init__(self, db=None, **values):
        """Create a new object with an optional Database association and
        initial field values.

        The object starts out clean: the initial values are not marked
        dirty.
        """
        self._db = db
        self._dirty = set()
        self._values_fixed = LazyConvertDict(self)
        self._values_flex = LazyConvertDict(self)

        # Initial contents.
        self.update(values)
        self.clear_dirty()

    @classmethod
    def _awaken(cls, db=None, fixed_values=None, flex_values=None):
        """Create an object with values drawn from the database.

        This is a performance optimization: the checks involved with
        ordinary construction are bypassed. `fixed_values` and
        `flex_values` are raw (unconverted) mappings; they are stored by
        reference and converted lazily on access. Omitted mappings
        default to fresh empty dicts.
        """
        obj = cls(db)

        # Use a new dict per call when none is given: the mapping is kept
        # by reference inside the object, so a shared default would be
        # aliased between every awakened instance.
        obj._values_fixed.init({} if fixed_values is None else fixed_values)
        obj._values_flex.init({} if flex_values is None else flex_values)

        return obj

    def __repr__(self):
        """Show the class name and all non-computed fields."""
        return '{0}({1})'.format(
            type(self).__name__,
            ', '.join('{0}={1!r}'.format(k, v) for k, v in dict(self).items()),
        )

    def clear_dirty(self):
        """Mark all fields as *clean* (i.e., not needing to be stored to
        the database).
        """
        self._dirty = set()

    def _check_db(self, need_id=True):
        """Ensure that this object is associated with a database row: it
        has a reference to a database (`_db`) and an id. A ValueError
        exception is raised otherwise.

        With `need_id` false, only the database reference is checked.
        """
        if not self._db:
            raise ValueError(
                u'{0} has no database'.format(type(self).__name__)
            )
        if need_id and not self.id:
            raise ValueError(u'{0} has no id'.format(type(self).__name__))

    def copy(self):
        """Create a copy of the model object.

        The field values and other state is duplicated, but the new copy
        remains associated with the same database as the old object.
        (A simple `copy.deepcopy` will not work because it would try to
        duplicate the SQLite connection.)
        """
        new = self.__class__()
        new._db = self._db
        new._values_fixed = self._values_fixed.copy()
        new._values_flex = self._values_flex.copy()
        new._dirty = self._dirty.copy()
        return new

    # Essential field accessors.

    @classmethod
    def _type(cls, key):
        """Get the type of a field, a `Type` instance.

        Fixed field types (`_fields`) take precedence over the optional
        types in `_types`. If the field has no explicit type, it is given
        the base `Type`, which does no conversion.
        """
        return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT

    def __getitem__(self, key):
        """Get the value for a field. Raise a KeyError if the field is
        not available.

        Computed fields shadow fixed fields, which shadow flexible ones.
        A fixed field without a stored value yields its type's null
        value.
        """
        getters = self._getters()
        if key in getters:  # Computed.
            return getters[key](self)
        elif key in self._fields:  # Fixed.
            if key in self._values_fixed:
                return self._values_fixed[key]
            else:
                return self._type(key).null
        elif key in self._values_flex:  # Flexible.
            return self._values_flex[key]
        else:
            raise KeyError(key)

    def _setitem(self, key, value):
        """Assign the value for a field, return whether new and old value
        differ.
        """
        # Choose where to place the value.
        if key in self._fields:
            source = self._values_fixed
        else:
            source = self._values_flex

        # If the field has a type, filter the value.
        value = self._type(key).normalize(value)

        # Assign value and possibly mark as dirty.
        old_value = source.get(key)
        source[key] = value
        changed = old_value != value
        if self._always_dirty or changed:
            self._dirty.add(key)

        return changed

    def __setitem__(self, key, value):
        """Assign the value for a field.
        """
        self._setitem(key, value)

    def __delitem__(self, key):
        """Remove a flexible attribute from the model.

        Deleting a fixed field resets it to its type's null value.
        Computed and unknown fields raise a KeyError.
        """
        if key in self._values_flex:  # Flexible.
            del self._values_flex[key]
            self._dirty.add(key)  # Mark for dropping on store.
        elif key in self._fields:  # Fixed
            setattr(self, key, self._type(key).null)
        elif key in self._getters():  # Computed.
            raise KeyError(u'computed field {0} cannot be deleted'.format(key))
        else:
            raise KeyError(u'no such field {0}'.format(key))

    def keys(self, computed=False):
        """Get a list of available field names for this object. The
        `computed` parameter controls whether computed (plugin-provided)
        fields are included in the key list.
        """
        base_keys = list(self._fields) + list(self._values_flex.keys())
        if computed:
            return base_keys + list(self._getters().keys())
        else:
            return base_keys

    @classmethod
    def all_keys(cls):
        """Get a list of available keys for objects of this type.
        Includes fixed and computed fields.
        """
        return list(cls._fields) + list(cls._getters().keys())

    # Act like a dictionary.

    def update(self, values):
        """Assign all values in the given dict.
        """
        for key, value in values.items():
            self[key] = value

    def items(self):
        """Iterate over (key, value) pairs that this object contains.
        Computed fields are not included.
        """
        for key in self:
            yield key, self[key]

    def get(self, key, default=None):
        """Get the value for a given key or `default` if it does not
        exist.
        """
        if key in self:
            return self[key]
        else:
            return default

    def __contains__(self, key):
        """Determine whether `key` is an attribute on this object.
        Computed fields count as present.
        """
        return key in self.keys(True)

    def __iter__(self):
        """Iterate over the available field names (excluding computed
        fields).
        """
        return iter(self.keys())

    # Convenient attribute access.

    def __getattr__(self, key):
        """Expose fields as attributes. Names starting with an underscore
        are never treated as fields.
        """
        if key.startswith('_'):
            raise AttributeError(u'model has no attribute {0!r}'.format(key))
        else:
            try:
                return self[key]
            except KeyError:
                raise AttributeError(u'no such field {0!r}'.format(key))

    def __setattr__(self, key, value):
        """Assign a field through attribute syntax. Underscore-prefixed
        names are stored as ordinary instance attributes.
        """
        if key.startswith('_'):
            super(Model, self).__setattr__(key, value)
        else:
            self[key] = value

    def __delattr__(self, key):
        """Delete a field through attribute syntax. Underscore-prefixed
        names are removed as ordinary instance attributes.
        """
        if key.startswith('_'):
            super(Model, self).__delattr__(key)
        else:
            del self[key]

    # Database interaction (CRUD methods).

    def store(self, fields=None):
        """Save the object's metadata into the library database.
        :param fields: the fields to be stored. If not specified, all fields
        will be.

        Only dirty fields are written. All dirty flags are cleared
        afterwards. Raises a ValueError if the object has no database or
        no id.
        """
        if fields is None:
            fields = self._fields
        self._check_db()

        # Build assignments for query.
        assignments = []
        subvars = []
        for key in fields:
            if key != 'id' and key in self._dirty:
                self._dirty.remove(key)
                assignments.append(key + '=?')
                value = self._type(key).to_sql(self[key])
                subvars.append(value)
        assignments = ','.join(assignments)

        with self._db.transaction() as tx:
            # Main table update.
            if assignments:
                query = 'UPDATE {0} SET {1} WHERE id=?'.format(
                    self._table, assignments
                )
                subvars.append(self.id)
                tx.mutate(query, subvars)

            # Modified/added flexible attributes.
            for key, value in self._values_flex.items():
                if key in self._dirty:
                    self._dirty.remove(key)
                    tx.mutate(
                        'INSERT INTO {0} '
                        '(entity_id, key, value) '
                        'VALUES (?, ?, ?);'.format(self._flex_table),
                        (self.id, key, value),
                    )

            # Deleted flexible attributes. Whatever is still dirty at this
            # point was neither written as a fixed column nor found among
            # the current flexible values.
            for key in self._dirty:
                tx.mutate(
                    'DELETE FROM {0} '
                    'WHERE entity_id=? AND key=?'.format(self._flex_table),
                    (self.id, key)
                )

        self.clear_dirty()

    def load(self):
        """Refresh the object's metadata from the library database.

        Local, unsaved changes are discarded and the object is marked
        clean. Raises a ValueError if the object has no database or no
        id.
        """
        self._check_db()
        stored_obj = self._db._get(type(self), self.id)
        assert stored_obj is not None, u"object {0} not in DB".format(self.id)
        self._values_fixed = LazyConvertDict(self)
        self._values_flex = LazyConvertDict(self)
        self.update(dict(stored_obj))
        self.clear_dirty()

    def remove(self):
        """Remove the object's associated rows from the database.

        Both the main-table row and all flexible-attribute rows are
        deleted in one transaction.
        """
        self._check_db()
        with self._db.transaction() as tx:
            tx.mutate(
                'DELETE FROM {0} WHERE id=?'.format(self._table),
                (self.id,)
            )
            tx.mutate(
                'DELETE FROM {0} WHERE entity_id=?'.format(self._flex_table),
                (self.id,)
            )

    def add(self, db=None):
        """Add the object to the library database. This object must be
        associated with a database; you can provide one via the `db`
        parameter or use the currently associated database.

        The object's `id` and `added` fields are set along with any
        current field values.
        """
        if db:
            self._db = db
        self._check_db(False)

        with self._db.transaction() as tx:
            new_id = tx.mutate(
                'INSERT INTO {0} DEFAULT VALUES'.format(self._table)
            )
            self.id = new_id
            self.added = time.time()

            # Mark every non-null field as dirty and store.
            for key in self:
                if self[key] is not None:
                    self._dirty.add(key)
            self.store()

    # Formatting and templating.

    _formatter = FormattedMapping

    def formatted(self, for_path=False):
        """Get a mapping containing all values on this object formatted
        as human-readable unicode strings.
        """
        return self._formatter(self, for_path)

    def evaluate_template(self, template, for_path=False):
        """Evaluate a template (a string or a `Template` object) using
        the object's fields. If `for_path` is true, then no new path
        separators will be added to the template.
        """
        # Perform substitution.
        if isinstance(template, six.string_types):
            template = functemplate.template(template)
        return template.substitute(self.formatted(for_path),
                                   self._template_funcs())

    # Parsing.

    @classmethod
    def _parse(cls, key, string):
        """Parse a string as a value for the given key.

        Raises a TypeError if `string` is not a string.
        """
        if not isinstance(string, six.string_types):
            raise TypeError(u"_parse() argument must be a string")

        return cls._type(key).parse(string)

    def set_parse(self, key, string):
        """Set the object's key to a value represented by a string.
        """
        self[key] = self._parse(key, string)
619
620
621# Database controller and supporting interfaces.
622
class Results(object):
    """An item query result set. Iterating over the collection lazily
    constructs LibModel objects that reflect database rows.
    """
    def __init__(self, model_class, rows, db, flex_rows,
                 query=None, sort=None):
        """Create a result set that will construct objects of type
        `model_class`.

        `model_class` is a subclass of `LibModel` that will be
        constructed. `rows` is a query result: a list of mappings. The
        new objects will be associated with the database `db`.
        `flex_rows` holds the flexible-attribute rows; each is indexed
        by `'entity_id'`, `'key'` and `'value'`.

        If `query` is provided, it is used as a predicate to filter the
        results for a "slow query" that cannot be evaluated by the
        database directly. If `sort` is provided, it is used to sort the
        full list of results before returning. This means it is a "slow
        sort" and all objects must be built before returning the first
        one.
        """
        self.model_class = model_class
        self.rows = rows
        self.db = db
        self.query = query
        self.sort = sort
        self.flex_rows = flex_rows

        # We keep a queue of rows we haven't yet consumed for
        # materialization. We preserve the original total number of
        # rows. Note that the queue is the very list passed in as `rows`;
        # it is consumed in place.
        self._rows = rows
        self._row_count = len(rows)

        # The materialized objects corresponding to rows that have been
        # consumed.
        self._objects = []

    def _get_objects(self):
        """Construct and generate Model objects for the query. The
        objects are returned in the order emitted from the database; no
        slow sort is applied.

        For performance, this generator caches materialized objects to
        avoid constructing them more than once. This way, iterating over
        a `Results` object a second time should be much faster than the
        first.
        """

        # Index flexible attributes by the item ID, so we have easier
        # access. The index is only needed to materialize new objects, so
        # skip building it when no unconsumed rows remain.
        flex_attrs = self._get_indexed_flex_attrs() if self._rows else {}

        index = 0  # Position in the materialized objects.
        while index < len(self._objects) or self._rows:
            # Are there previously-materialized objects to produce?
            if index < len(self._objects):
                yield self._objects[index]
                index += 1

            # Otherwise, we consume another row, materialize its object
            # and produce it.
            else:
                while self._rows:
                    row = self._rows.pop(0)
                    obj = self._make_model(row, flex_attrs.get(row['id'], {}))
                    # If there is a slow-query predicate, ensure that the
                    # object passes it.
                    if not self.query or self.query.match(obj):
                        self._objects.append(obj)
                        index += 1
                        yield obj
                        break

    def __iter__(self):
        """Construct and generate Model objects for all matching
        objects, in sorted order.
        """
        if self.sort:
            # Slow sort. Must build the full list first.
            objects = self.sort.sort(list(self._get_objects()))
            return iter(objects)

        else:
            # Objects are pre-sorted (i.e., by the database).
            return self._get_objects()

    def _get_indexed_flex_attrs(self):
        """Index flexible attributes by the entity id they belong to.

        Returns a dict mapping each entity id to a `{key: value}` dict.
        """
        flex_values = {}
        for row in self.flex_rows:
            flex_values.setdefault(row['entity_id'], {})[row['key']] = \
                row['value']

        return flex_values

    def _make_model(self, row, flex_values=None):
        """Create a Model object for the given row.

        `flex_values` maps flexible attribute names to raw values; a
        fresh empty dict is used when it is omitted.
        """
        if flex_values is None:
            # Never share one dict between objects: the model keeps a
            # reference to it.
            flex_values = {}

        cols = dict(row)
        values = dict((k, v) for (k, v) in cols.items()
                      if not k[:4] == 'flex')

        # Construct the Python object
        obj = self.model_class._awaken(self.db, values, flex_values)
        return obj

    def __len__(self):
        """Get the number of matching objects.
        """
        if not self._rows:
            # Fully materialized. Just count the objects.
            return len(self._objects)

        elif self.query:
            # A slow query. Fall back to testing every object. Counting
            # does not depend on order, so bypass any slow sort.
            return sum(1 for _ in self._get_objects())

        else:
            # A fast query. Just count the rows.
            return self._row_count

    def __nonzero__(self):
        """Does this result contain any objects?
        """
        return self.__bool__()

    def __bool__(self):
        """Does this result contain any objects?
        """
        return bool(len(self))

    def __getitem__(self, n):
        """Get the nth item in this result set. This is inefficient: all
        items up to n are materialized and thrown away.

        Negative indices count from the end and require materializing
        the complete result set. Raises an IndexError when `n` is out of
        range.
        """
        if not self._rows and not self.sort:
            # Fully materialized and already in order. Just look up the
            # object.
            return self._objects[n]

        if n < 0:
            # Counting from the end needs the complete (sorted) list;
            # stepping an iterator `n` times would silently return the
            # first element instead.
            objects = list(self)
            try:
                return objects[n]
            except IndexError:
                raise IndexError(u'result index {0} out of range'.format(n))

        it = iter(self)
        try:
            for i in range(n):
                next(it)
            return next(it)
        except StopIteration:
            raise IndexError(u'result index {0} out of range'.format(n))

    def get(self):
        """Return the first matching object, or None if no objects
        match.
        """
        it = iter(self)
        try:
            return next(it)
        except StopIteration:
            return None
785
786
class Transaction(object):
    """A context manager for safe, concurrent access to the database.
    All SQL commands should be executed through a transaction.

    Transactions may be nested within a thread. Only the outermost
    ("root") transaction acquires the database lock on entry and commits
    and releases the lock on exit.
    """
    def __init__(self, db):
        self.db = db

    def __enter__(self):
        """Begin a transaction. This transaction may be created while
        another is active in a different thread.
        """
        with self.db._tx_stack() as stack:
            first = not stack
            stack.append(self)
        if first:
            # Beginning a "root" transaction, which corresponds to an
            # SQLite transaction.
            self.db._db_lock.acquire()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Complete a transaction. This must be the most recently
        entered but not yet exited transaction. If it is the last active
        transaction, the database updates are committed.

        The exception arguments are not inspected: the commit is
        attempted whether or not the body raised, and any exception
        propagates to the caller.
        """
        with self.db._tx_stack() as stack:
            # Pop outside of the `assert` so that the stack is still
            # unwound when assertions are disabled (`python -O`).
            popped = stack.pop()
            assert popped is self
            empty = not stack
        if empty:
            # Ending a "root" transaction. End the SQLite transaction.
            try:
                self.db._connection().commit()
            finally:
                # Always give the lock back, even if the commit fails;
                # otherwise every later transaction would block forever.
                self.db._db_lock.release()

    def query(self, statement, subvals=()):
        """Execute an SQL statement with substitution values and return
        a list of rows from the database.
        """
        cursor = self.db._connection().execute(statement, subvals)
        return cursor.fetchall()

    def mutate(self, statement, subvals=()):
        """Execute an SQL statement with substitution values and return
        the row ID of the last affected row.

        Raises a DBAccessError when SQLite reports that the database
        file is read-only or cannot be opened; other operational errors
        propagate unchanged.
        """
        try:
            cursor = self.db._connection().execute(statement, subvals)
            return cursor.lastrowid
        except sqlite3.OperationalError as e:
            # In two specific cases, SQLite reports an error while accessing
            # the underlying database file. We surface these exceptions as
            # DBAccessError so the application can abort.
            if e.args[0] in ("attempt to write a readonly database",
                             "unable to open database file"):
                raise DBAccessError(e.args[0])
            else:
                raise

    def script(self, statements):
        """Execute a string containing multiple SQL statements."""
        self.db._connection().executescript(statements)
847
848
class Database(object):
    """A container for Model objects that wraps an SQLite database as
    the backend.

    Every thread that touches the database gets its own SQLite
    connection, opened lazily on first use, and its own stack of active
    transactions. When the object is constructed, the tables for each
    Model subclass listed in `_models` are created or extended as
    needed.
    """

    _models = ()
    """The Model subclasses representing tables in this database.
    """

    supports_extensions = hasattr(sqlite3.Connection, 'enable_load_extension')
    """Whether or not the current version of SQLite supports extensions"""

    def __init__(self, path, timeout=5.0):
        """Create a database object for the SQLite file at `path` and
        set up its schema.

        `timeout` is handed to `sqlite3.connect` for every connection
        this object opens.
        """
        self.path = path
        self.timeout = timeout

        self._connections = {}
        self._tx_stacks = defaultdict(list)
        self._extensions = []

        # A lock to protect the _connections and _tx_stacks maps, which
        # both map thread IDs to private resources.
        self._shared_map_lock = threading.Lock()

        # A lock to protect access to the database itself. SQLite does
        # allow multiple threads to access the database at the same
        # time, but many users were experiencing crashes related to this
        # capability: where SQLite was compiled without HAVE_USLEEP, its
        # backoff algorithm in the case of contention was causing
        # whole-second sleeps (!) that would trigger its internal
        # timeout. Using this lock ensures only one SQLite transaction
        # is active at a time.
        self._db_lock = threading.Lock()

        # Set up database schema.
        for model_cls in self._models:
            self._make_table(model_cls._table, model_cls._fields)
            self._make_attribute_table(model_cls._flex_table)

    # Primitive access control: connections and transactions.

    def _connection(self):
        """Get a SQLite connection object to the underlying database.
        One connection object is created per thread and reused by later
        calls from that thread.
        """
        thread_id = threading.current_thread().ident
        with self._shared_map_lock:
            conn = self._connections.get(thread_id)
            if conn is None:
                conn = self._create_connection()
                self._connections[thread_id] = conn
            return conn

    def _create_connection(self):
        """Create a SQLite connection to the underlying database.

        Makes a new connection every time. If you need to configure the
        connection settings (e.g., add custom functions), override this
        method.
        """
        # Make a new connection. The `sqlite3` module can't use
        # bytestring paths here on Python 3, so we need to
        # provide a `str` using `py3_path`.
        conn = sqlite3.connect(
            py3_path(self.path), timeout=self.timeout
        )

        if self.supports_extensions:
            conn.enable_load_extension(True)

            # Load any extensions that were already registered through
            # `load_extension`, so that every connection has them.
            for path in self._extensions:
                conn.load_extension(path)

        # Access SELECT results like dictionaries.
        conn.row_factory = sqlite3.Row
        return conn

    def _close(self):
        """Close all connections to the underlying SQLite database
        from all threads. This does not render the database object
        unusable; new connections can still be opened on demand.
        """
        with self._shared_map_lock:
            while self._connections:
                _thread_id, conn = self._connections.popitem()
                try:
                    conn.close()
                except sqlite3.ProgrammingError:
                    # By default, `sqlite3` refuses to operate on a
                    # connection from any thread other than the one
                    # that created it. Such a connection cannot be
                    # closed from here, so it is only dropped from the
                    # map.
                    pass

    @contextlib.contextmanager
    def _tx_stack(self):
        """A context manager providing access to the current thread's
        transaction stack. The context manager synchronizes access to
        the stack map. Transactions should never migrate across threads.
        """
        thread_id = threading.current_thread().ident
        # The map lock stays held for the whole `with` body of the
        # caller, so that body should be short.
        with self._shared_map_lock:
            yield self._tx_stacks[thread_id]

    def transaction(self):
        """Get a :class:`Transaction` object for interacting directly
        with the underlying SQLite database.
        """
        return Transaction(self)

    def load_extension(self, path):
        """Load an SQLite extension into all open connections.

        The path is also remembered so that connections opened later
        load the extension too. Raises `ValueError` if this `sqlite3`
        installation does not support extensions.
        """
        if not self.supports_extensions:
            raise ValueError(
                'this sqlite3 installation does not support extensions')

        # Record the extension and snapshot the open connections in a
        # single step under the map lock: a connection created
        # concurrently is then either in the snapshot or picks the
        # extension up from `_extensions`, and the dictionary is never
        # iterated while another thread resizes it.
        with self._shared_map_lock:
            self._extensions.append(path)
            open_connections = list(self._connections.values())

        # Load the extension into every open connection.
        for conn in open_connections:
            conn.load_extension(path)

    # Schema setup and migration.

    def _make_table(self, table, fields):
        """Set up the schema of the database. `fields` is a mapping
        from field names to `Type`s. Columns are added if necessary.
        """
        # Get current schema. The second column of each `table_info`
        # row holds the column name.
        with self.transaction() as tx:
            rows = tx.query('PRAGMA table_info(%s)' % table)
        current_fields = {row[1] for row in rows}

        if current_fields.issuperset(fields.keys()):
            # Table exists and has all the required columns.
            return

        if not current_fields:
            # No table exists.
            columns = ', '.join(
                '{0} {1}'.format(name, typ.sql)
                for name, typ in fields.items()
            )
            setup_sql = 'CREATE TABLE {0} ({1});\n'.format(table, columns)
        else:
            # Table exists but is missing some fields: add a column for
            # each one that is absent.
            setup_sql = ''.join(
                'ALTER TABLE {0} ADD COLUMN {1} {2};\n'.format(
                    table, name, typ.sql
                )
                for name, typ in fields.items()
                if name not in current_fields
            )

        with self.transaction() as tx:
            tx.script(setup_sql)

    def _make_attribute_table(self, flex_table):
        """Create a table and associated index for flexible attributes
        for the given entity (if they don't exist).
        """
        with self.transaction() as tx:
            tx.script("""
                CREATE TABLE IF NOT EXISTS {0} (
                    id INTEGER PRIMARY KEY,
                    entity_id INTEGER,
                    key TEXT,
                    value TEXT,
                    UNIQUE(entity_id, key) ON CONFLICT REPLACE);
                CREATE INDEX IF NOT EXISTS {0}_by_entity
                    ON {0} (entity_id);
                """.format(flex_table))

    # Querying.

    def _fetch(self, model_cls, query=None, sort=None):
        """Fetch the objects of type `model_cls` matching the given
        query. `query` is a Query object or None (to fetch
        everything). `sort` is a `Sort` object or None (to leave the
        results unsorted).
        """
        query = query or TrueQuery()  # A null query.
        sort = sort or NullSort()  # Unsorted.
        where, subvals = query.clause()
        order_by = sort.order_clause()

        # When the query yields no SQL clause, select every row here
        # and pass the query itself on to `Results` below.
        where_clause = where or '1'

        sql = ("SELECT * FROM {0} WHERE {1} {2}").format(
            model_cls._table,
            where_clause,
            "ORDER BY {0}".format(order_by) if order_by else '',
        )

        # Fetch flexible attributes for items matching the main query.
        # Doing the per-item filtering in python is faster than issuing
        # one query per item to sqlite.
        flex_sql = ("""
            SELECT * FROM {0} WHERE entity_id IN
                (SELECT id FROM {1} WHERE {2});
            """.format(
                model_cls._flex_table,
                model_cls._table,
                where_clause,
            )
        )

        with self.transaction() as tx:
            rows = tx.query(sql, subvals)
            flex_rows = tx.query(flex_sql, subvals)

        return Results(
            model_cls, rows, self, flex_rows,
            None if where else query,  # Slow query component.
            sort if sort.is_slow() else None,  # Slow sort component.
        )

    def _get(self, model_cls, id):
        """Get a Model object by its id or None if the id does not
        exist.
        """
        return self._fetch(model_cls, MatchQuery('id', id)).get()
1064