1# vim: ft=python fileencoding=utf-8 sts=4 sw=4 et:
2
3# Copyright 2016-2021 Ryan Roden-Corrent (rcorre) <ryan@rcorre.net>
4#
5# This file is part of qutebrowser.
6#
7# qutebrowser is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# qutebrowser is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with qutebrowser.  If not, see <https://www.gnu.org/licenses/>.
19
20"""Provides access to sqlite databases."""
21
22import collections
23import contextlib
24import dataclasses
25import types
26from typing import Any, Dict, Iterator, List, Mapping, MutableSequence, Optional, Type
27
28from PyQt5.QtCore import QObject, pyqtSignal
29from PyQt5.QtSql import QSqlDatabase, QSqlError, QSqlQuery
30
31from qutebrowser.qt import sip
32from qutebrowser.utils import debug, log
33
34
@dataclasses.dataclass
class UserVersion:

    """A major/minor version of the data stored in the history database.

    Originally, sqlite's user_version was only used as a signal that the completion
    database needed to be regenerated. As some changes are backwards-incompatible,
    the (signed, 32-bit) user_version integer is now split into two parts:

    - minor (lower 16 bits): changes we can cope with transparently (e.g. there are
      only new URLs to clean up or somesuch).
    - major (upper 15 bits): backwards-incompatible changes in how the database
      works, i.e. newer databases can't be used with older qutebrowser versions.
    """

    major: int
    minor: int

    @classmethod
    def from_int(cls, num: int) -> 'UserVersion':
        """Split an sqlite user_version integer into its major/minor parts."""
        # The sqlite value is a signed integer, but to_int() never produces
        # negative values.
        assert 0 <= num <= 0x7FFF_FFFF, num
        return cls(major=(num >> 16) & 0x7FFF, minor=num & 0xFFFF)

    def to_int(self) -> int:
        """Combine the major/minor parts into an sqlite user_version integer."""
        assert 0 <= self.major <= 0x7FFF  # keep the sign bit clear
        assert 0 <= self.minor <= 0xFFFF
        return (self.major << 16) | self.minor

    def __str__(self) -> str:
        return f'{self.major}.{self.minor}'
70
71
class SqliteErrorCode:

    """Primary sqlite result codes, stored as strings.

    Only the codes qutebrowser checks for are listed; they are strings because they
    get compared against QSqlError.nativeErrorCode(). The full list lives at
    https://sqlite.org/rescode.html
    """

    ERROR = '1'  # SQLITE_ERROR: generic error code
    BUSY = '5'  # SQLITE_BUSY: database is locked
    READONLY = '8'  # SQLITE_READONLY: attempt to write a readonly database
    IOERR = '10'  # SQLITE_IOERR: disk I/O error
    CORRUPT = '11'  # SQLITE_CORRUPT: database disk image is malformed
    FULL = '13'  # SQLITE_FULL: database or disk is full
    CANTOPEN = '14'  # SQLITE_CANTOPEN: unable to open database file
    PROTOCOL = '15'  # SQLITE_PROTOCOL: locking protocol error
    CONSTRAINT = '19'  # SQLITE_CONSTRAINT: e.g. UNIQUE constraint failed
    NOTADB = '26'  # SQLITE_NOTADB: file is not a database
90
91
class Error(Exception):

    """Common base class for all SQL-related exceptions in this module.

    Attributes:
        error: The QSqlError which caused this exception, or None.
    """

    def __init__(self, msg: str, error: Optional[QSqlError] = None) -> None:
        super().__init__(msg)
        self.error = error

    def text(self) -> str:
        """Return a short description of the error, suitable to show to the user.

        This is the database's own error text if a QSqlError is attached, and the
        exception message otherwise.
        """
        if self.error is not None:
            return self.error.databaseText()
        return str(self)
109
110
class KnownError(Error):

    """An SQL error caused by the environment rather than by qutebrowser.

    Used for conditions like a full disk or I/O errors, where qutebrowser isn't
    to blame.
    """
118
119
class BugError(Error):

    """An SQL error which indicates a bug in qutebrowser itself.

    Raised for SQL errors which aren't explained by the environment.
    """
126
127
def raise_sqlite_error(msg: str, error: QSqlError) -> None:
    """Raise either a BugError or KnownError.

    Errors caused by the environment (locked, read-only, full or corrupt databases,
    I/O errors, ...) and queries exceeding sqlite's limits are raised as KnownError.
    Everything else is treated as a qutebrowser bug and raised as BugError.

    Args:
        msg: The message for the raised exception.
        error: The QSqlError describing what went wrong.

    Raises:
        KnownError: For the error conditions listed above.
        BugError: For all other errors.
    """
    error_code = error.nativeErrorCode()
    database_text = error.databaseText()
    driver_text = error.driverText()

    log.sql.debug("SQL error:")
    log.sql.debug(f"type: {debug.qenum_key(QSqlError, error.type())}")
    log.sql.debug(f"database text: {database_text}")
    log.sql.debug(f"driver text: {driver_text}")
    log.sql.debug(f"error code: {error_code}")

    # sqlite can report extended result codes (e.g. 778 == SQLITE_IOERR_WRITE),
    # whose least significant 8 bits are the primary result code. Classify by
    # the primary code so those aren't mistaken for bugs.
    # See https://sqlite.org/rescode.html
    try:
        primary_error_code = str(int(error_code) & 0xFF)
    except ValueError:
        # Not a number (e.g. an empty string) - compare it as-is.
        primary_error_code = error_code

    known_errors = {
        SqliteErrorCode.BUSY,
        SqliteErrorCode.READONLY,
        SqliteErrorCode.IOERR,
        SqliteErrorCode.CORRUPT,
        SqliteErrorCode.FULL,
        SqliteErrorCode.CANTOPEN,
        SqliteErrorCode.PROTOCOL,
        SqliteErrorCode.NOTADB,
    }

    # https://github.com/qutebrowser/qutebrowser/issues/4681
    # If the query we built was too long
    too_long_err = (
        primary_error_code == SqliteErrorCode.ERROR and
        (database_text.startswith("Expression tree is too large") or
         database_text in ["too many SQL variables",
                           "LIKE or GLOB pattern too complex"]))

    if primary_error_code in known_errors or too_long_err:
        raise KnownError(msg, error)

    raise BugError(msg, error)
163
164
class Database:

    """A wrapper over a QSqlDatabase connection."""

    _USER_VERSION = UserVersion(0, 4)  # The current / newest user version

    def __init__(self, path: str) -> None:
        """Open a Qt sqlite connection to `path`.

        Args:
            path: The sqlite database to open. Also used as the Qt connection name.

        Raises:
            BugError: If a connection for `path` already exists.
            KnownError: If the Qt sqlite driver isn't usable, or the database's
                        major user version is newer than the supported one.
            BugError/KnownError: If opening the database fails (see
                                 raise_sqlite_error).
        """
        if QSqlDatabase.database(path).isValid():
            raise BugError(f'A connection to the database at "{path}" already exists')

        self._path = path
        database = QSqlDatabase.addDatabase('QSQLITE', path)
        if not database.isValid():
            raise KnownError('Failed to add database. Are sqlite and Qt sqlite '
                             'support installed?')
        database.setDatabaseName(path)
        if not database.open():
            error = database.lastError()
            msg = f"Failed to open sqlite database at {path}: {error.text()}"
            # Unregister the connection again. Otherwise, a later attempt to open
            # the same path is rejected with "already exists" above. As in close(),
            # the wrapper is deleted first, presumably so Qt doesn't consider the
            # connection to still be in use.
            sip.delete(database)
            QSqlDatabase.removeDatabase(path)
            raise_sqlite_error(msg, error)

        version_int = self.query('pragma user_version').run().value()
        self._user_version = UserVersion.from_int(version_int)

        if self._user_version.major > self._USER_VERSION.major:
            raise KnownError(
                "Database is too new for this qutebrowser version (database version "
                f"{self._user_version}, but {self._USER_VERSION.major}.x is supported)")

        if self.user_version_changed():
            # Enable write-ahead-logging and reduce disk write frequency
            # see https://sqlite.org/pragma.html and issues #2930 and #3507
            #
            # We might already have done this (without a migration) in earlier versions,
            # but as those are idempotent, let's make sure we run them once again.
            self.query("PRAGMA journal_mode=WAL").run()
            self.query("PRAGMA synchronous=NORMAL").run()

    def qt_database(self) -> QSqlDatabase:
        """Return the wrapped QSqlDatabase instance.

        Raises:
            BugError: If no valid connection exists (e.g. after close()).
        """
        database = QSqlDatabase.database(self._path, open=True)
        if not database.isValid():
            raise BugError('Failed to get connection. Did you close() this Database '
                           'instance?')
        return database

    def query(self, querystr: str, forward_only: bool = True) -> 'Query':
        """Return a Query instance linked to this Database."""
        return Query(self, querystr, forward_only)

    def table(self, name: str, fields: List[str],
              constraints: Optional[Dict[str, str]] = None,
              parent: Optional[QObject] = None) -> 'SqlTable':
        """Return a SqlTable instance linked to this Database."""
        return SqlTable(self, name, fields, constraints, parent)

    def user_version_changed(self) -> bool:
        """Whether the version stored in the database differs from the current one."""
        return self._user_version != self._USER_VERSION

    def upgrade_user_version(self) -> None:
        """Upgrade the user version to the latest version.

        This method should be called once all required operations to migrate from one
        version to another have been run.
        """
        log.sql.debug(f"Migrating from version {self._user_version} "
                      f"to {self._USER_VERSION}")
        self.query(f'PRAGMA user_version = {self._USER_VERSION.to_int()}').run()
        self._user_version = self._USER_VERSION

    def close(self) -> None:
        """Close the SQL connection and unregister it from Qt."""
        database = self.qt_database()
        database.close()
        # Delete our wrapper before removing the connection, so it isn't
        # referenced anymore when Qt drops it.
        sip.delete(database)
        QSqlDatabase.removeDatabase(self._path)

    def transaction(self) -> 'Transaction':
        """Return a Transaction object linked to this Database."""
        return Transaction(self)
246
247
class Transaction(contextlib.AbstractContextManager):  # type: ignore[type-arg]

    """Context manager running its body inside a Database transaction.

    Entering begins the transaction. Leaving commits it, or rolls it back if the
    body raised an exception.
    """

    def __init__(self, database: Database) -> None:
        self._database = database

    def __enter__(self) -> None:
        log.sql.debug('Starting a transaction')
        qt_db = self._database.qt_database()
        if qt_db.transaction():
            return
        err = qt_db.lastError()
        raise_sqlite_error(f'Failed to start a transaction: "{err.text()}"', err)

    def __exit__(self,
                 _exc_type: Optional[Type[BaseException]],
                 exc_val: Optional[BaseException],
                 _exc_tb: Optional[types.TracebackType]) -> None:
        qt_db = self._database.qt_database()
        if not exc_val:
            log.sql.debug('Committing a transaction')
            if not qt_db.commit():
                err = qt_db.lastError()
                raise_sqlite_error(
                    f'Failed to commit a transaction: "{err.text()}"', err)
            return

        # The body raised: discard its changes. The exception propagates since
        # we return None.
        log.sql.debug('Rolling back a transaction')
        qt_db.rollback()
279
280
class Query:

    """A prepared SQL query."""

    def __init__(self, database: Database, querystr: str,
                 forward_only: bool = True) -> None:
        """Prepare a new SQL query.

        Args:
            database: The Database object on which to operate.
            querystr: String to prepare query from.
            forward_only: Optimization for queries that will only step forward.
                          Must be false for completion queries.

        Raises:
            BugError/KnownError: If preparing the query fails.
        """
        self._database = database
        self.query = QSqlQuery(database.qt_database())

        log.sql.vdebug(f'Preparing: {querystr}')  # type: ignore[attr-defined]
        ok = self.query.prepare(querystr)
        self._check_ok('prepare', ok)
        self.query.setForwardOnly(forward_only)

    def __iter__(self) -> Iterator[Any]:
        """Yield the result rows as namedtuples with one field per column.

        Raises:
            BugError: (when iteration starts) if the query hasn't been run.
        """
        if not self.query.isActive():
            raise BugError("Cannot iterate inactive query")
        rec = self.query.record()
        fields = [rec.fieldName(i) for i in range(rec.count())]
        rowtype = collections.namedtuple(  # type: ignore[misc]
            'ResultRow', fields)

        while self.query.next():
            rec = self.query.record()
            yield rowtype(*[rec.value(i) for i in range(rec.count())])

    def _check_ok(self, step: str, ok: bool) -> None:
        """Raise an SQL error based on the query's last error if `ok` is false."""
        if not ok:
            query = self.query.lastQuery()
            error = self.query.lastError()
            msg = f'Failed to {step} query "{query}": "{error.text()}"'
            raise_sqlite_error(msg, error)

    def _check_db_ok(self, step: str, ok: bool, db: QSqlDatabase) -> None:
        """Raise an SQL error based on the connection's last error if `ok` is false.

        Used for db.transaction()/db.commit(). Their failures are reported via the
        connection's lastError() (as Transaction uses too), not the query's.
        """
        if not ok:
            query = self.query.lastQuery()
            error = db.lastError()
            msg = f'Failed to {step} query "{query}": "{error.text()}"'
            raise_sqlite_error(msg, error)

    def _bind_values(self, values: Mapping[str, Any]) -> Dict[str, Any]:
        """Bind each `values` item to the `:key` placeholder.

        Return:
            All values bound to the query.

        Raises:
            BugError: If any placeholder ends up with a None value.
        """
        for key, val in values.items():
            self.query.bindValue(f':{key}', val)

        bound_values = self.bound_values()
        if None in bound_values.values():
            raise BugError("Missing bound values!")

        return bound_values

    def run(self, **values: Any) -> 'Query':
        """Execute the prepared query.

        Args:
            **values: Values for the query's `:name` placeholders.

        Return:
            self, so calls like `.run().value()` can be chained.
        """
        log.sql.debug(self.query.lastQuery())

        bound_values = self._bind_values(values)
        if bound_values:
            log.sql.debug(f'    {bound_values}')

        ok = self.query.exec()
        self._check_ok('exec', ok)

        return self

    def run_batch(self, values: Mapping[str, MutableSequence[Any]]) -> None:
        """Execute the query in batch mode, inside a transaction.

        Args:
            values: Maps each placeholder name to a list of values (one per row).
        """
        log.sql.debug(f'Running SQL query (batch): "{self.query.lastQuery()}"')

        self._bind_values(values)

        db = self._database.qt_database()
        ok = db.transaction()
        self._check_db_ok('transaction', ok, db)

        ok = self.query.execBatch()
        try:
            self._check_ok('execBatch', ok)
        except Error:
            # Not checking the return value here, as we're failing anyways...
            db.rollback()
            raise

        ok = db.commit()
        try:
            self._check_db_ok('commit', ok, db)
        except Error:
            # Best effort: don't leave the transaction open on the connection,
            # which would make later transactions fail. Not checking the return
            # value, as we're failing anyways...
            db.rollback()
            raise

    def value(self) -> Any:
        """Return the result of a single-value query (e.g. an EXISTS)."""
        if not self.query.next():
            raise BugError("No result for single-result query")
        return self.query.record().value(0)

    def rows_affected(self) -> int:
        """Return how many rows were affected by a non-SELECT query."""
        assert not self.query.isSelect(), self
        assert self.query.isActive(), self
        rows = self.query.numRowsAffected()
        assert rows != -1
        return rows

    def bound_values(self) -> Dict[str, Any]:
        """Return the values currently bound to the query's placeholders."""
        return self.query.boundValues()
382
383
class SqlTable(QObject):

    """Interface to a SQL table.

    Attributes:
        _name: Name of the SQL table this wraps.
        database: The Database to which this table belongs.

    Signals:
        changed: Emitted when the table is modified.
    """

    changed = pyqtSignal()
    database: Database

    def __init__(self, database: Database, name: str, fields: List[str],
                 constraints: Optional[Dict[str, str]] = None,
                 parent: Optional[QObject] = None) -> None:
        """Wrapper over a table in the SQL database.

        Args:
            database: The Database to which this table belongs.
            name: Name of the table.
            fields: A list of field names.
            constraints: A dict mapping field names to constraint strings.
            parent: The parent QObject.
        """
        super().__init__(parent)
        self._name = name
        self.database = database
        self._create_table(fields, constraints)

    def _create_table(self, fields: List[str], constraints: Optional[Dict[str, str]],
                      *, force: bool = False) -> None:
        """Create the table if the database is uninitialized.

        If the table already exists, this does nothing (except with force=True), so it
        can e.g. be called on every user_version change.
        """
        if not self.database.user_version_changed() and not force:
            return

        constraints = constraints or {}
        column_defs = [f'{field} {constraints.get(field, "")}'
                       for field in fields]
        q = self.database.query(
            f"CREATE TABLE IF NOT EXISTS {self._name} ({', '.join(column_defs)})"
        )
        q.run()

    def create_index(self, name: str, field: str) -> None:
        """Create an index over this table if the database is uninitialized.

        Args:
            name: Name of the index, should be unique.
            field: Name of the field to index.
        """
        if not self.database.user_version_changed():
            return

        q = self.database.query(
            f"CREATE INDEX IF NOT EXISTS {name} ON {self._name} ({field})"
        )
        q.run()

    def __iter__(self) -> Iterator[Any]:
        """Iterate rows in the table."""
        q = self.database.query(f"SELECT * FROM {self._name}")
        q.run()
        return iter(q)

    def contains_query(self, field: str) -> Query:
        """Return a prepared query that checks for the existence of an item.

        The value to look for is bound to the `:val` placeholder when running it.

        Args:
            field: Field to match.
        """
        return self.database.query(
            f"SELECT EXISTS(SELECT * FROM {self._name} WHERE {field} = :val)"
        )

    def __len__(self) -> int:
        """Return the count of rows in the table."""
        q = self.database.query(f"SELECT count(*) FROM {self._name}")
        q.run()
        return q.value()

    def __bool__(self) -> bool:
        """Check whether there's any data in the table."""
        q = self.database.query(f"SELECT 1 FROM {self._name} LIMIT 1")
        q.run()
        return q.query.next()

    def delete(self, field: str, value: Any) -> None:
        """Remove all rows for which `field` equals `value`.

        Args:
            field: Field to use as the key.
            value: Key value to delete.

        Raises:
            KeyError: If no row matched.
        """
        q = self.database.query(f"DELETE FROM {self._name} where {field} = :val")
        q.run(val=value)
        if not q.rows_affected():
            raise KeyError(f'No row with {field} = "{value}"')
        self.changed.emit()

    def _insert_query(self, values: Mapping[str, Any], replace: bool) -> Query:
        """Prepare an INSERT (or REPLACE) query with one placeholder per key."""
        params = ', '.join(f':{key}' for key in values)
        columns = ', '.join(values)
        verb = "REPLACE" if replace else "INSERT"
        return self.database.query(
            f"{verb} INTO {self._name} ({columns}) values({params})"
        )

    def insert(self, values: Mapping[str, Any], replace: bool = False) -> None:
        """Append a row to the table.

        Args:
            values: A dict with a value to insert for each field name.
            replace: If set, replace existing values.
        """
        q = self._insert_query(values, replace)
        q.run(**values)
        self.changed.emit()

    def insert_batch(self, values: Mapping[str, MutableSequence[Any]],
                     replace: bool = False) -> None:
        """Performantly append multiple rows to the table.

        Args:
            values: A dict with a list of values to insert for each field name.
            replace: If true, overwrite rows with a primary key match.
        """
        q = self._insert_query(values, replace)
        q.run_batch(values)
        self.changed.emit()

    def delete_all(self) -> None:
        """Remove all rows from the table."""
        self.database.query(f"DELETE FROM {self._name}").run()
        self.changed.emit()

    def select(self, sort_by: str, sort_order: str, limit: int = -1) -> Query:
        """Prepare, run, and return a select statement on this table.

        Args:
            sort_by: name of column to sort by.
            sort_order: 'asc' or 'desc'.
            limit: max number of rows in result, defaults to -1 (unlimited).

        Return: A prepared and executed select query.
        """
        q = self.database.query(
            f"SELECT * FROM {self._name} ORDER BY {sort_by} {sort_order} LIMIT :limit"
        )
        q.run(limit=limit)
        return q
543
544
def version() -> str:
    """Return the sqlite version string.

    The version is queried from a temporary in-memory database, which is closed
    again afterwards. On a KnownError (e.g. when the Qt sqlite driver can't be
    used), 'UNAVAILABLE (<error>)' is returned instead.
    """
    try:
        with contextlib.closing(Database(':memory:')) as mem_db:
            sqlite_ver = mem_db.query("select sqlite_version()").run().value()
    except KnownError as e:
        return f'UNAVAILABLE ({e})'
    return sqlite_ver
552