1# -*- coding: utf-8 -*- 2# This file is part of beets. 3# Copyright 2016, Adrian Sampson. 4# 5# Permission is hereby granted, free of charge, to any person obtaining 6# a copy of this software and associated documentation files (the 7# "Software"), to deal in the Software without restriction, including 8# without limitation the rights to use, copy, modify, merge, publish, 9# distribute, sublicense, and/or sell copies of the Software, and to 10# permit persons to whom the Software is furnished to do so, subject to 11# the following conditions: 12# 13# The above copyright notice and this permission notice shall be 14# included in all copies or substantial portions of the Software. 15 16"""The central Model and Database constructs for DBCore. 17""" 18from __future__ import division, absolute_import, print_function 19 20import time 21import os 22from collections import defaultdict 23import threading 24import sqlite3 25import contextlib 26 27import beets 28from beets.util import functemplate 29from beets.util import py3_path 30from beets.dbcore import types 31from .query import MatchQuery, NullSort, TrueQuery 32import six 33if six.PY2: 34 from collections import Mapping 35else: 36 from collections.abc import Mapping 37 38 39class DBAccessError(Exception): 40 """The SQLite database became inaccessible. 41 42 This can happen when trying to read or write the database when, for 43 example, the database file is deleted or otherwise disappears. There 44 is probably no way to recover from this error. 45 """ 46 47 48class FormattedMapping(Mapping): 49 """A `dict`-like formatted view of a model. 50 51 The accessor `mapping[key]` returns the formatted version of 52 `model[key]` as a unicode string. 53 54 If `for_path` is true, all path separators in the formatted values 55 are replaced. 56 """ 57 58 def __init__(self, model, for_path=False): 59 self.for_path = for_path 60 self.model = model 61 self.model_keys = model.keys(True) 62 63 def __getitem__(self, key): 64 if key in self.model_keys: 65 return self._get_formatted(self.model, key) 66 else: 67 raise KeyError(key) 68 69 def __iter__(self): 70 return iter(self.model_keys) 71 72 def __len__(self): 73 return len(self.model_keys) 74 75 def get(self, key, default=None): 76 if default is None: 77 default = self.model._type(key).format(None) 78 return super(FormattedMapping, self).get(key, default) 79 80 def _get_formatted(self, model, key): 81 value = model._type(key).format(model.get(key)) 82 if isinstance(value, bytes): 83 value = value.decode('utf-8', 'ignore') 84 85 if self.for_path: 86 sep_repl = beets.config['path_sep_replace'].as_str() 87 for sep in (os.path.sep, os.path.altsep): 88 if sep: 89 value = value.replace(sep, sep_repl) 90 91 return value 92 93 94class LazyConvertDict(object): 95 """Lazily convert types for attributes fetched from the database 96 """ 97 98 def __init__(self, model_cls): 99 """Initialize the object empty 100 """ 101 self.data = {} 102 self.model_cls = model_cls 103 self._converted = {} 104 105 def init(self, data): 106 """Set the base data that should be lazily converted 107 """ 108 self.data = data 109 110 def _convert(self, key, value): 111 """Convert the attribute type according the the SQL type 112 """ 113 return self.model_cls._type(key).from_sql(value) 114 115 def __setitem__(self, key, value): 116 """Set an attribute value, assume it's already converted 117 """ 118 self._converted[key] = value 119 120 def __getitem__(self, key): 121 """Get an attribute value, converting the type on demand 122 if needed 123 """ 124 if key in self._converted: 125 return self._converted[key] 126 elif key in self.data: 127 value = self._convert(key, self.data[key]) 128 self._converted[key] = value 129 return value 130 131 def __delitem__(self, key): 132 """Delete both converted and base data 133 """ 134 if key in self._converted: 135 del self._converted[key] 136 if key in self.data: 137 del self.data[key] 138 139 def keys(self): 140 """Get a list of available field names for this object. 141 """ 142 return list(self._converted.keys()) + list(self.data.keys()) 143 144 def copy(self): 145 """Create a copy of the object. 146 """ 147 new = self.__class__(self.model_cls) 148 new.data = self.data.copy() 149 new._converted = self._converted.copy() 150 return new 151 152 # Act like a dictionary. 153 154 def update(self, values): 155 """Assign all values in the given dict. 156 """ 157 for key, value in values.items(): 158 self[key] = value 159 160 def items(self): 161 """Iterate over (key, value) pairs that this object contains. 162 Computed fields are not included. 163 """ 164 for key in self: 165 yield key, self[key] 166 167 def get(self, key, default=None): 168 """Get the value for a given key or `default` if it does not 169 exist. 170 """ 171 if key in self: 172 return self[key] 173 else: 174 return default 175 176 def __contains__(self, key): 177 """Determine whether `key` is an attribute on this object. 178 """ 179 return key in self.keys() 180 181 def __iter__(self): 182 """Iterate over the available field names (excluding computed 183 fields). 184 """ 185 return iter(self.keys()) 186 187 188# Abstract base for model classes. 189 190class Model(object): 191 """An abstract object representing an object in the database. Model 192 objects act like dictionaries (i.e., the allow subscript access like 193 ``obj['field']``). The same field set is available via attribute 194 access as a shortcut (i.e., ``obj.field``). Three kinds of attributes are 195 available: 196 197 * **Fixed attributes** come from a predetermined list of field 198 names. These fields correspond to SQLite table columns and are 199 thus fast to read, write, and query. 200 * **Flexible attributes** are free-form and do not need to be listed 201 ahead of time. 202 * **Computed attributes** are read-only fields computed by a getter 203 function provided by a plugin. 204 205 Access to all three field types is uniform: ``obj.field`` works the 206 same regardless of whether ``field`` is fixed, flexible, or 207 computed. 208 209 Model objects can optionally be associated with a `Library` object, 210 in which case they can be loaded and stored from the database. Dirty 211 flags are used to track which fields need to be stored. 212 """ 213 214 # Abstract components (to be provided by subclasses). 215 216 _table = None 217 """The main SQLite table name. 218 """ 219 220 _flex_table = None 221 """The flex field SQLite table name. 222 """ 223 224 _fields = {} 225 """A mapping indicating available "fixed" fields on this type. The 226 keys are field names and the values are `Type` objects. 227 """ 228 229 _search_fields = () 230 """The fields that should be queried by default by unqualified query 231 terms. 232 """ 233 234 _types = {} 235 """Optional Types for non-fixed (i.e., flexible and computed) fields. 236 """ 237 238 _sorts = {} 239 """Optional named sort criteria. The keys are strings and the values 240 are subclasses of `Sort`. 241 """ 242 243 _queries = {} 244 """Named queries that use a field-like `name:value` syntax but which 245 do not relate to any specific field. 246 """ 247 248 _always_dirty = False 249 """By default, fields only become "dirty" when their value actually 250 changes. Enabling this flag marks fields as dirty even when the new 251 value is the same as the old value (e.g., `o.f = o.f`). 252 """ 253 254 @classmethod 255 def _getters(cls): 256 """Return a mapping from field names to getter functions. 257 """ 258 # We could cache this if it becomes a performance problem to 259 # gather the getter mapping every time. 260 raise NotImplementedError() 261 262 def _template_funcs(self): 263 """Return a mapping from function names to text-transformer 264 functions. 265 """ 266 # As above: we could consider caching this result. 267 raise NotImplementedError() 268 269 # Basic operation. 270 271 def __init__(self, db=None, **values): 272 """Create a new object with an optional Database association and 273 initial field values. 274 """ 275 self._db = db 276 self._dirty = set() 277 self._values_fixed = LazyConvertDict(self) 278 self._values_flex = LazyConvertDict(self) 279 280 # Initial contents. 281 self.update(values) 282 self.clear_dirty() 283 284 @classmethod 285 def _awaken(cls, db=None, fixed_values={}, flex_values={}): 286 """Create an object with values drawn from the database. 287 288 This is a performance optimization: the checks involved with 289 ordinary construction are bypassed. 290 """ 291 obj = cls(db) 292 293 obj._values_fixed.init(fixed_values) 294 obj._values_flex.init(flex_values) 295 296 return obj 297 298 def __repr__(self): 299 return '{0}({1})'.format( 300 type(self).__name__, 301 ', '.join('{0}={1!r}'.format(k, v) for k, v in dict(self).items()), 302 ) 303 304 def clear_dirty(self): 305 """Mark all fields as *clean* (i.e., not needing to be stored to 306 the database). 307 """ 308 self._dirty = set() 309 310 def _check_db(self, need_id=True): 311 """Ensure that this object is associated with a database row: it 312 has a reference to a database (`_db`) and an id. A ValueError 313 exception is raised otherwise. 314 """ 315 if not self._db: 316 raise ValueError( 317 u'{0} has no database'.format(type(self).__name__) 318 ) 319 if need_id and not self.id: 320 raise ValueError(u'{0} has no id'.format(type(self).__name__)) 321 322 def copy(self): 323 """Create a copy of the model object. 324 325 The field values and other state is duplicated, but the new copy 326 remains associated with the same database as the old object. 327 (A simple `copy.deepcopy` will not work because it would try to 328 duplicate the SQLite connection.) 329 """ 330 new = self.__class__() 331 new._db = self._db 332 new._values_fixed = self._values_fixed.copy() 333 new._values_flex = self._values_flex.copy() 334 new._dirty = self._dirty.copy() 335 return new 336 337 # Essential field accessors. 338 339 @classmethod 340 def _type(cls, key): 341 """Get the type of a field, a `Type` instance. 342 343 If the field has no explicit type, it is given the base `Type`, 344 which does no conversion. 345 """ 346 return cls._fields.get(key) or cls._types.get(key) or types.DEFAULT 347 348 def __getitem__(self, key): 349 """Get the value for a field. Raise a KeyError if the field is 350 not available. 351 """ 352 getters = self._getters() 353 if key in getters: # Computed. 354 return getters[key](self) 355 elif key in self._fields: # Fixed. 356 if key in self._values_fixed: 357 return self._values_fixed[key] 358 else: 359 return self._type(key).null 360 elif key in self._values_flex: # Flexible. 361 return self._values_flex[key] 362 else: 363 raise KeyError(key) 364 365 def _setitem(self, key, value): 366 """Assign the value for a field, return whether new and old value 367 differ. 368 """ 369 # Choose where to place the value. 370 if key in self._fields: 371 source = self._values_fixed 372 else: 373 source = self._values_flex 374 375 # If the field has a type, filter the value. 376 value = self._type(key).normalize(value) 377 378 # Assign value and possibly mark as dirty. 379 old_value = source.get(key) 380 source[key] = value 381 changed = old_value != value 382 if self._always_dirty or changed: 383 self._dirty.add(key) 384 385 return changed 386 387 def __setitem__(self, key, value): 388 """Assign the value for a field. 389 """ 390 self._setitem(key, value) 391 392 def __delitem__(self, key): 393 """Remove a flexible attribute from the model. 394 """ 395 if key in self._values_flex: # Flexible. 396 del self._values_flex[key] 397 self._dirty.add(key) # Mark for dropping on store. 398 elif key in self._fields: # Fixed 399 setattr(self, key, self._type(key).null) 400 elif key in self._getters(): # Computed. 401 raise KeyError(u'computed field {0} cannot be deleted'.format(key)) 402 else: 403 raise KeyError(u'no such field {0}'.format(key)) 404 405 def keys(self, computed=False): 406 """Get a list of available field names for this object. The 407 `computed` parameter controls whether computed (plugin-provided) 408 fields are included in the key list. 409 """ 410 base_keys = list(self._fields) + list(self._values_flex.keys()) 411 if computed: 412 return base_keys + list(self._getters().keys()) 413 else: 414 return base_keys 415 416 @classmethod 417 def all_keys(cls): 418 """Get a list of available keys for objects of this type. 419 Includes fixed and computed fields. 420 """ 421 return list(cls._fields) + list(cls._getters().keys()) 422 423 # Act like a dictionary. 424 425 def update(self, values): 426 """Assign all values in the given dict. 427 """ 428 for key, value in values.items(): 429 self[key] = value 430 431 def items(self): 432 """Iterate over (key, value) pairs that this object contains. 433 Computed fields are not included. 434 """ 435 for key in self: 436 yield key, self[key] 437 438 def get(self, key, default=None): 439 """Get the value for a given key or `default` if it does not 440 exist. 441 """ 442 if key in self: 443 return self[key] 444 else: 445 return default 446 447 def __contains__(self, key): 448 """Determine whether `key` is an attribute on this object. 449 """ 450 return key in self.keys(True) 451 452 def __iter__(self): 453 """Iterate over the available field names (excluding computed 454 fields). 455 """ 456 return iter(self.keys()) 457 458 # Convenient attribute access. 459 460 def __getattr__(self, key): 461 if key.startswith('_'): 462 raise AttributeError(u'model has no attribute {0!r}'.format(key)) 463 else: 464 try: 465 return self[key] 466 except KeyError: 467 raise AttributeError(u'no such field {0!r}'.format(key)) 468 469 def __setattr__(self, key, value): 470 if key.startswith('_'): 471 super(Model, self).__setattr__(key, value) 472 else: 473 self[key] = value 474 475 def __delattr__(self, key): 476 if key.startswith('_'): 477 super(Model, self).__delattr__(key) 478 else: 479 del self[key] 480 481 # Database interaction (CRUD methods). 482 483 def store(self, fields=None): 484 """Save the object's metadata into the library database. 485 :param fields: the fields to be stored. If not specified, all fields 486 will be. 487 """ 488 if fields is None: 489 fields = self._fields 490 self._check_db() 491 492 # Build assignments for query. 493 assignments = [] 494 subvars = [] 495 for key in fields: 496 if key != 'id' and key in self._dirty: 497 self._dirty.remove(key) 498 assignments.append(key + '=?') 499 value = self._type(key).to_sql(self[key]) 500 subvars.append(value) 501 assignments = ','.join(assignments) 502 503 with self._db.transaction() as tx: 504 # Main table update. 505 if assignments: 506 query = 'UPDATE {0} SET {1} WHERE id=?'.format( 507 self._table, assignments 508 ) 509 subvars.append(self.id) 510 tx.mutate(query, subvars) 511 512 # Modified/added flexible attributes. 513 for key, value in self._values_flex.items(): 514 if key in self._dirty: 515 self._dirty.remove(key) 516 tx.mutate( 517 'INSERT INTO {0} ' 518 '(entity_id, key, value) ' 519 'VALUES (?, ?, ?);'.format(self._flex_table), 520 (self.id, key, value), 521 ) 522 523 # Deleted flexible attributes. 524 for key in self._dirty: 525 tx.mutate( 526 'DELETE FROM {0} ' 527 'WHERE entity_id=? AND key=?'.format(self._flex_table), 528 (self.id, key) 529 ) 530 531 self.clear_dirty() 532 533 def load(self): 534 """Refresh the object's metadata from the library database. 535 """ 536 self._check_db() 537 stored_obj = self._db._get(type(self), self.id) 538 assert stored_obj is not None, u"object {0} not in DB".format(self.id) 539 self._values_fixed = LazyConvertDict(self) 540 self._values_flex = LazyConvertDict(self) 541 self.update(dict(stored_obj)) 542 self.clear_dirty() 543 544 def remove(self): 545 """Remove the object's associated rows from the database. 546 """ 547 self._check_db() 548 with self._db.transaction() as tx: 549 tx.mutate( 550 'DELETE FROM {0} WHERE id=?'.format(self._table), 551 (self.id,) 552 ) 553 tx.mutate( 554 'DELETE FROM {0} WHERE entity_id=?'.format(self._flex_table), 555 (self.id,) 556 ) 557 558 def add(self, db=None): 559 """Add the object to the library database. This object must be 560 associated with a database; you can provide one via the `db` 561 parameter or use the currently associated database. 562 563 The object's `id` and `added` fields are set along with any 564 current field values. 565 """ 566 if db: 567 self._db = db 568 self._check_db(False) 569 570 with self._db.transaction() as tx: 571 new_id = tx.mutate( 572 'INSERT INTO {0} DEFAULT VALUES'.format(self._table) 573 ) 574 self.id = new_id 575 self.added = time.time() 576 577 # Mark every non-null field as dirty and store. 578 for key in self: 579 if self[key] is not None: 580 self._dirty.add(key) 581 self.store() 582 583 # Formatting and templating. 584 585 _formatter = FormattedMapping 586 587 def formatted(self, for_path=False): 588 """Get a mapping containing all values on this object formatted 589 as human-readable unicode strings. 590 """ 591 return self._formatter(self, for_path) 592 593 def evaluate_template(self, template, for_path=False): 594 """Evaluate a template (a string or a `Template` object) using 595 the object's fields. If `for_path` is true, then no new path 596 separators will be added to the template. 597 """ 598 # Perform substitution. 599 if isinstance(template, six.string_types): 600 template = functemplate.template(template) 601 return template.substitute(self.formatted(for_path), 602 self._template_funcs()) 603 604 # Parsing. 605 606 @classmethod 607 def _parse(cls, key, string): 608 """Parse a string as a value for the given key. 609 """ 610 if not isinstance(string, six.string_types): 611 raise TypeError(u"_parse() argument must be a string") 612 613 return cls._type(key).parse(string) 614 615 def set_parse(self, key, string): 616 """Set the object's key to a value represented by a string. 617 """ 618 self[key] = self._parse(key, string) 619 620 621# Database controller and supporting interfaces. 622 623class Results(object): 624 """An item query result set. Iterating over the collection lazily 625 constructs LibModel objects that reflect database rows. 626 """ 627 def __init__(self, model_class, rows, db, flex_rows, 628 query=None, sort=None): 629 """Create a result set that will construct objects of type 630 `model_class`. 631 632 `model_class` is a subclass of `LibModel` that will be 633 constructed. `rows` is a query result: a list of mappings. The 634 new objects will be associated with the database `db`. 635 636 If `query` is provided, it is used as a predicate to filter the 637 results for a "slow query" that cannot be evaluated by the 638 database directly. If `sort` is provided, it is used to sort the 639 full list of results before returning. This means it is a "slow 640 sort" and all objects must be built before returning the first 641 one. 642 """ 643 self.model_class = model_class 644 self.rows = rows 645 self.db = db 646 self.query = query 647 self.sort = sort 648 self.flex_rows = flex_rows 649 650 # We keep a queue of rows we haven't yet consumed for 651 # materialization. We preserve the original total number of 652 # rows. 653 self._rows = rows 654 self._row_count = len(rows) 655 656 # The materialized objects corresponding to rows that have been 657 # consumed. 658 self._objects = [] 659 660 def _get_objects(self): 661 """Construct and generate Model objects for they query. The 662 objects are returned in the order emitted from the database; no 663 slow sort is applied. 664 665 For performance, this generator caches materialized objects to 666 avoid constructing them more than once. This way, iterating over 667 a `Results` object a second time should be much faster than the 668 first. 669 """ 670 671 # Index flexible attributes by the item ID, so we have easier access 672 flex_attrs = self._get_indexed_flex_attrs() 673 674 index = 0 # Position in the materialized objects. 675 while index < len(self._objects) or self._rows: 676 # Are there previously-materialized objects to produce? 677 if index < len(self._objects): 678 yield self._objects[index] 679 index += 1 680 681 # Otherwise, we consume another row, materialize its object 682 # and produce it. 683 else: 684 while self._rows: 685 row = self._rows.pop(0) 686 obj = self._make_model(row, flex_attrs.get(row['id'], {})) 687 # If there is a slow-query predicate, ensurer that the 688 # object passes it. 689 if not self.query or self.query.match(obj): 690 self._objects.append(obj) 691 index += 1 692 yield obj 693 break 694 695 def __iter__(self): 696 """Construct and generate Model objects for all matching 697 objects, in sorted order. 698 """ 699 if self.sort: 700 # Slow sort. Must build the full list first. 701 objects = self.sort.sort(list(self._get_objects())) 702 return iter(objects) 703 704 else: 705 # Objects are pre-sorted (i.e., by the database). 706 return self._get_objects() 707 708 def _get_indexed_flex_attrs(self): 709 """ Index flexible attributes by the entity id they belong to 710 """ 711 flex_values = dict() 712 for row in self.flex_rows: 713 if row['entity_id'] not in flex_values: 714 flex_values[row['entity_id']] = dict() 715 716 flex_values[row['entity_id']][row['key']] = row['value'] 717 718 return flex_values 719 720 def _make_model(self, row, flex_values={}): 721 """ Create a Model object for the given row 722 """ 723 cols = dict(row) 724 values = dict((k, v) for (k, v) in cols.items() 725 if not k[:4] == 'flex') 726 727 # Construct the Python object 728 obj = self.model_class._awaken(self.db, values, flex_values) 729 return obj 730 731 def __len__(self): 732 """Get the number of matching objects. 733 """ 734 if not self._rows: 735 # Fully materialized. Just count the objects. 736 return len(self._objects) 737 738 elif self.query: 739 # A slow query. Fall back to testing every object. 740 count = 0 741 for obj in self: 742 count += 1 743 return count 744 745 else: 746 # A fast query. Just count the rows. 747 return self._row_count 748 749 def __nonzero__(self): 750 """Does this result contain any objects? 751 """ 752 return self.__bool__() 753 754 def __bool__(self): 755 """Does this result contain any objects? 756 """ 757 return bool(len(self)) 758 759 def __getitem__(self, n): 760 """Get the nth item in this result set. This is inefficient: all 761 items up to n are materialized and thrown away. 762 """ 763 if not self._rows and not self.sort: 764 # Fully materialized and already in order. Just look up the 765 # object. 766 return self._objects[n] 767 768 it = iter(self) 769 try: 770 for i in range(n): 771 next(it) 772 return next(it) 773 except StopIteration: 774 raise IndexError(u'result index {0} out of range'.format(n)) 775 776 def get(self): 777 """Return the first matching object, or None if no objects 778 match. 779 """ 780 it = iter(self) 781 try: 782 return next(it) 783 except StopIteration: 784 return None 785 786 787class Transaction(object): 788 """A context manager for safe, concurrent access to the database. 789 All SQL commands should be executed through a transaction. 790 """ 791 def __init__(self, db): 792 self.db = db 793 794 def __enter__(self): 795 """Begin a transaction. This transaction may be created while 796 another is active in a different thread. 797 """ 798 with self.db._tx_stack() as stack: 799 first = not stack 800 stack.append(self) 801 if first: 802 # Beginning a "root" transaction, which corresponds to an 803 # SQLite transaction. 804 self.db._db_lock.acquire() 805 return self 806 807 def __exit__(self, exc_type, exc_value, traceback): 808 """Complete a transaction. This must be the most recently 809 entered but not yet exited transaction. If it is the last active 810 transaction, the database updates are committed. 811 """ 812 with self.db._tx_stack() as stack: 813 assert stack.pop() is self 814 empty = not stack 815 if empty: 816 # Ending a "root" transaction. End the SQLite transaction. 817 self.db._connection().commit() 818 self.db._db_lock.release() 819 820 def query(self, statement, subvals=()): 821 """Execute an SQL statement with substitution values and return 822 a list of rows from the database. 823 """ 824 cursor = self.db._connection().execute(statement, subvals) 825 return cursor.fetchall() 826 827 def mutate(self, statement, subvals=()): 828 """Execute an SQL statement with substitution values and return 829 the row ID of the last affected row. 830 """ 831 try: 832 cursor = self.db._connection().execute(statement, subvals) 833 return cursor.lastrowid 834 except sqlite3.OperationalError as e: 835 # In two specific cases, SQLite reports an error while accessing 836 # the underlying database file. We surface these exceptions as 837 # DBAccessError so the application can abort. 838 if e.args[0] in ("attempt to write a readonly database", 839 "unable to open database file"): 840 raise DBAccessError(e.args[0]) 841 else: 842 raise 843 844 def script(self, statements): 845 """Execute a string containing multiple SQL statements.""" 846 self.db._connection().executescript(statements) 847 848 849class Database(object): 850 """A container for Model objects that wraps an SQLite database as 851 the backend. 852 """ 853 854 _models = () 855 """The Model subclasses representing tables in this database. 856 """ 857 858 supports_extensions = hasattr(sqlite3.Connection, 'enable_load_extension') 859 """Whether or not the current version of SQLite supports extensions""" 860 861 def __init__(self, path, timeout=5.0): 862 self.path = path 863 self.timeout = timeout 864 865 self._connections = {} 866 self._tx_stacks = defaultdict(list) 867 self._extensions = [] 868 869 # A lock to protect the _connections and _tx_stacks maps, which 870 # both map thread IDs to private resources. 871 self._shared_map_lock = threading.Lock() 872 873 # A lock to protect access to the database itself. SQLite does 874 # allow multiple threads to access the database at the same 875 # time, but many users were experiencing crashes related to this 876 # capability: where SQLite was compiled without HAVE_USLEEP, its 877 # backoff algorithm in the case of contention was causing 878 # whole-second sleeps (!) that would trigger its internal 879 # timeout. Using this lock ensures only one SQLite transaction 880 # is active at a time. 881 self._db_lock = threading.Lock() 882 883 # Set up database schema. 884 for model_cls in self._models: 885 self._make_table(model_cls._table, model_cls._fields) 886 self._make_attribute_table(model_cls._flex_table) 887 888 # Primitive access control: connections and transactions. 889 890 def _connection(self): 891 """Get a SQLite connection object to the underlying database. 892 One connection object is created per thread. 893 """ 894 thread_id = threading.current_thread().ident 895 with self._shared_map_lock: 896 if thread_id in self._connections: 897 return self._connections[thread_id] 898 else: 899 conn = self._create_connection() 900 self._connections[thread_id] = conn 901 return conn 902 903 def _create_connection(self): 904 """Create a SQLite connection to the underlying database. 905 906 Makes a new connection every time. If you need to configure the 907 connection settings (e.g., add custom functions), override this 908 method. 909 """ 910 # Make a new connection. The `sqlite3` module can't use 911 # bytestring paths here on Python 3, so we need to 912 # provide a `str` using `py3_path`. 913 conn = sqlite3.connect( 914 py3_path(self.path), timeout=self.timeout 915 ) 916 917 if self.supports_extensions: 918 conn.enable_load_extension(True) 919 920 # Load any extension that are already loaded for other connections. 921 for path in self._extensions: 922 conn.load_extension(path) 923 924 # Access SELECT results like dictionaries. 925 conn.row_factory = sqlite3.Row 926 return conn 927 928 def _close(self): 929 """Close the all connections to the underlying SQLite database 930 from all threads. This does not render the database object 931 unusable; new connections can still be opened on demand. 932 """ 933 with self._shared_map_lock: 934 self._connections.clear() 935 936 @contextlib.contextmanager 937 def _tx_stack(self): 938 """A context manager providing access to the current thread's 939 transaction stack. The context manager synchronizes access to 940 the stack map. Transactions should never migrate across threads. 941 """ 942 thread_id = threading.current_thread().ident 943 with self._shared_map_lock: 944 yield self._tx_stacks[thread_id] 945 946 def transaction(self): 947 """Get a :class:`Transaction` object for interacting directly 948 with the underlying SQLite database. 949 """ 950 return Transaction(self) 951 952 def load_extension(self, path): 953 """Load an SQLite extension into all open connections.""" 954 if not self.supports_extensions: 955 raise ValueError( 956 'this sqlite3 installation does not support extensions') 957 958 self._extensions.append(path) 959 960 # Load the extension into every open connection. 961 for conn in self._connections.values(): 962 conn.load_extension(path) 963 964 # Schema setup and migration. 965 966 def _make_table(self, table, fields): 967 """Set up the schema of the database. `fields` is a mapping 968 from field names to `Type`s. Columns are added if necessary. 969 """ 970 # Get current schema. 971 with self.transaction() as tx: 972 rows = tx.query('PRAGMA table_info(%s)' % table) 973 current_fields = set([row[1] for row in rows]) 974 975 field_names = set(fields.keys()) 976 if current_fields.issuperset(field_names): 977 # Table exists and has all the required columns. 978 return 979 980 if not current_fields: 981 # No table exists. 982 columns = [] 983 for name, typ in fields.items(): 984 columns.append('{0} {1}'.format(name, typ.sql)) 985 setup_sql = 'CREATE TABLE {0} ({1});\n'.format(table, 986 ', '.join(columns)) 987 988 else: 989 # Table exists does not match the field set. 990 setup_sql = '' 991 for name, typ in fields.items(): 992 if name in current_fields: 993 continue 994 setup_sql += 'ALTER TABLE {0} ADD COLUMN {1} {2};\n'.format( 995 table, name, typ.sql 996 ) 997 998 with self.transaction() as tx: 999 tx.script(setup_sql) 1000 1001 def _make_attribute_table(self, flex_table): 1002 """Create a table and associated index for flexible attributes 1003 for the given entity (if they don't exist). 1004 """ 1005 with self.transaction() as tx: 1006 tx.script(""" 1007 CREATE TABLE IF NOT EXISTS {0} ( 1008 id INTEGER PRIMARY KEY, 1009 entity_id INTEGER, 1010 key TEXT, 1011 value TEXT, 1012 UNIQUE(entity_id, key) ON CONFLICT REPLACE); 1013 CREATE INDEX IF NOT EXISTS {0}_by_entity 1014 ON {0} (entity_id); 1015 """.format(flex_table)) 1016 1017 # Querying. 1018 1019 def _fetch(self, model_cls, query=None, sort=None): 1020 """Fetch the objects of type `model_cls` matching the given 1021 query. The query may be given as a string, string sequence, a 1022 Query object, or None (to fetch everything). `sort` is an 1023 `Sort` object. 1024 """ 1025 query = query or TrueQuery() # A null query. 1026 sort = sort or NullSort() # Unsorted. 1027 where, subvals = query.clause() 1028 order_by = sort.order_clause() 1029 1030 sql = ("SELECT * FROM {0} WHERE {1} {2}").format( 1031 model_cls._table, 1032 where or '1', 1033 "ORDER BY {0}".format(order_by) if order_by else '', 1034 ) 1035 1036 # Fetch flexible attributes for items matching the main query. 1037 # Doing the per-item filtering in python is faster than issuing 1038 # one query per item to sqlite. 1039 flex_sql = (""" 1040 SELECT * FROM {0} WHERE entity_id IN 1041 (SELECT id FROM {1} WHERE {2}); 1042 """.format( 1043 model_cls._flex_table, 1044 model_cls._table, 1045 where or '1', 1046 ) 1047 ) 1048 1049 with self.transaction() as tx: 1050 rows = tx.query(sql, subvals) 1051 flex_rows = tx.query(flex_sql, subvals) 1052 1053 return Results( 1054 model_cls, rows, self, flex_rows, 1055 None if where else query, # Slow query component. 1056 sort if sort.is_slow() else None, # Slow sort component. 1057 ) 1058 1059 def _get(self, model_cls, id): 1060 """Get a Model object by its id or None if the id does not 1061 exist. 1062 """ 1063 return self._fetch(model_cls, MatchQuery('id', id)).get() 1064