1# Copyright 2018 John Reese
2# Licensed under the MIT license
3
4"""
5Core implementation of aiosqlite proxies
6"""
7
8import asyncio
9import logging
10import sqlite3
11import sys
12import warnings
13from functools import partial
14from pathlib import Path
15from queue import Empty, Queue
16from threading import Thread
17from typing import (
18    Any,
19    AsyncIterator,
20    Callable,
21    Generator,
22    Iterable,
23    Optional,
24    Type,
25    Union,
26)
27from warnings import warn
28
29from .context import contextmanager
30from .cursor import Cursor
31
32__all__ = ["connect", "Connection", "Cursor"]
33
34LOG = logging.getLogger("aiosqlite")
35
36
def get_loop(future: asyncio.Future) -> asyncio.AbstractEventLoop:
    """Return the event loop that *future* is attached to."""
    if sys.version_info < (3, 7):
        # Future.get_loop() only exists on 3.7+; older interpreters have to
        # read the private attribute instead.
        return future._loop
    return future.get_loop()
42
43
class Connection(Thread):
    """
    Proxy around a sqlite3 connection that is driven from a worker thread.

    Coroutine methods wrap the sqlite3 call they need into a
    ``(future, function)`` pair and place it on an internal queue. The thread
    body (`run`) executes the function and reports the result or exception
    back to the future on that future's own event loop.
    """

    def __init__(
        self,
        connector: Callable[[], sqlite3.Connection],
        iter_chunk_size: int,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """
        Prepare the proxy; the database itself is opened later by `_connect`.

        :param connector: zero-argument callable returning the sqlite3
            connection; it is executed on the worker thread.
        :param iter_chunk_size: stored as ``_iter_chunk_size``; not used in
            this class (presumably read by Cursor -- verify there).
        :param loop: ignored; passing a value only emits a
            DeprecationWarning.
        """
        super().__init__()
        self._running = True
        self._connection: Optional[sqlite3.Connection] = None
        self._connector = connector
        self._tx: Queue = Queue()
        self._iter_chunk_size = iter_chunk_size

        if loop is not None:
            # stacklevel=2 attributes the warning to the code that passed
            # `loop`, not to this module.
            warn(
                "aiosqlite.Connection no longer uses the `loop` parameter",
                DeprecationWarning,
                stacklevel=2,
            )

    @property
    def _conn(self) -> sqlite3.Connection:
        """
        Return the underlying sqlite3 connection.

        :raises ValueError: if no connection is currently open.
        """
        if self._connection is None:
            raise ValueError("no active connection")

        return self._connection

    def _execute_insert(
        self, sql: str, parameters: Iterable[Any]
    ) -> Optional[sqlite3.Row]:
        """Run *sql*, then return the row from ``SELECT last_insert_rowid()``."""
        cursor = self._conn.execute(sql, parameters)
        cursor.execute("SELECT last_insert_rowid()")
        return cursor.fetchone()

    def _execute_fetchall(
        self, sql: str, parameters: Iterable[Any]
    ) -> Iterable[sqlite3.Row]:
        """Run *sql* and return every resulting row."""
        cursor = self._conn.execute(sql, parameters)
        return cursor.fetchall()

    def run(self) -> None:
        """
        Execute function calls on a separate thread.

        :meta private:
        """

        # These callbacks are scheduled onto the future's event loop. The
        # done() guard skips futures that were already resolved or cancelled
        # by the time the callback runs.
        def set_result(fut, result):
            if not fut.done():
                fut.set_result(result)

        def set_exception(fut, e):
            if not fut.done():
                fut.set_exception(e)

        while True:
            # Continues running until all queue items are processed,
            # even after connection is closed (so we can finalize all
            # futures)
            try:
                future, function = self._tx.get(timeout=0.1)
            except Empty:
                if self._running:
                    continue
                break
            try:
                LOG.debug("executing %s", function)
                result = function()
                LOG.debug("operation %s completed", function)
                get_loop(future).call_soon_threadsafe(set_result, future, result)
            except BaseException as e:
                # Whatever the queued call raised is forwarded to the awaiting
                # coroutine rather than propagating out of this thread.
                LOG.debug("returning exception %s", e)
                get_loop(future).call_soon_threadsafe(set_exception, future, e)

    async def _execute(self, fn, *args, **kwargs):
        """
        Queue a function with the given arguments for execution.

        :returns: whatever ``fn(*args, **kwargs)`` returned on the worker
            thread.
        :raises ValueError: if the connection is closed or was never opened.
        """
        if not self._running or not self._connection:
            raise ValueError("Connection closed")

        function = partial(fn, *args, **kwargs)
        future = asyncio.get_event_loop().create_future()

        self._tx.put_nowait((future, function))

        return await future

    async def _connect(self) -> "Connection":
        """Connect to the actual sqlite database."""
        if self._connection is None:
            try:
                # `_execute` refuses to run without a connection, so the
                # connector is queued directly.
                future = asyncio.get_event_loop().create_future()
                self._tx.put_nowait((future, self._connector))
                self._connection = await future
            except Exception:
                # Clearing _running lets the worker loop exit once its queue
                # is empty.
                self._running = False
                self._connection = None
                raise

        return self

    def __await__(self) -> Generator[Any, None, "Connection"]:
        """
        Start the worker thread and open the database.

        ``Thread.start()`` may only be called once, so a Connection can only
        be awaited once.
        """
        self.start()
        return self._connect().__await__()

    async def __aenter__(self) -> "Connection":
        """Open the connection on entering ``async with``."""
        return await self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the connection on leaving ``async with``."""
        await self.close()

    @contextmanager
    async def cursor(self) -> Cursor:
        """Create an aiosqlite cursor wrapping a sqlite3 cursor object."""
        return Cursor(self, await self._execute(self._conn.cursor))

    async def commit(self) -> None:
        """Commit the current transaction."""
        await self._execute(self._conn.commit)

    async def rollback(self) -> None:
        """Roll back the current transaction."""
        await self._execute(self._conn.rollback)

    async def close(self) -> None:
        """
        Complete queued queries/cursors and close the connection.

        Calling this on a connection that is already closed (or was never
        opened) is a no-op, so an explicit ``close()`` inside an
        ``async with`` block does not make ``__aexit__`` fail.
        """
        if self._connection is None:
            # Nothing to hand to the worker; still make sure it is told to
            # wind down.
            self._running = False
            return

        try:
            await self._execute(self._conn.close)
        except Exception:
            LOG.info("exception occurred while closing connection")
            raise
        finally:
            self._running = False
            self._connection = None

    @contextmanager
    async def execute(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Cursor:
        """Helper to create a cursor and execute the given query."""
        if parameters is None:
            parameters = []
        cursor = await self._execute(self._conn.execute, sql, parameters)
        return Cursor(self, cursor)

    @contextmanager
    async def execute_insert(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Optional[sqlite3.Row]:
        """Helper to insert and get the last_insert_rowid."""
        if parameters is None:
            parameters = []
        return await self._execute(self._execute_insert, sql, parameters)

    @contextmanager
    async def execute_fetchall(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Iterable[sqlite3.Row]:
        """Helper to execute a query and return all the data."""
        if parameters is None:
            parameters = []
        return await self._execute(self._execute_fetchall, sql, parameters)

    @contextmanager
    async def executemany(
        self, sql: str, parameters: Iterable[Iterable[Any]]
    ) -> Cursor:
        """Helper to create a cursor and execute the given multiquery."""
        cursor = await self._execute(self._conn.executemany, sql, parameters)
        return Cursor(self, cursor)

    @contextmanager
    async def executescript(self, sql_script: str) -> Cursor:
        """Helper to create a cursor and execute a user script."""
        cursor = await self._execute(self._conn.executescript, sql_script)
        return Cursor(self, cursor)

    async def interrupt(self) -> None:
        """
        Interrupt pending queries.

        Called directly on the connection instead of being queued, so it is
        not held up behind the queries it is supposed to abort.
        """
        return self._conn.interrupt()

    async def create_function(
        self, name: str, num_params: int, func: Callable, deterministic: bool = False
    ) -> None:
        """
        Create user-defined function that can be later used
        within SQL statements. Must be run within the same thread
        that query executions take place so instead of executing directly
        against the connection, we defer this to `run` function.

        In Python 3.8 and above, if *deterministic* is true, the created
        function is marked as deterministic, which allows SQLite to perform
        additional optimizations. This flag is supported by SQLite 3.8.3 or
        higher, ``NotSupportedError`` will be raised if used with older
        versions.

        On older Python versions the flag cannot be passed on; a warning is
        emitted and the function is registered without it.
        """
        if sys.version_info >= (3, 8):
            await self._execute(
                self._conn.create_function,
                name,
                num_params,
                func,
                deterministic=deterministic,
            )
        else:
            if deterministic:
                warnings.warn(
                    "Deterministic function support is only available on "
                    'Python 3.8+. Function "{}" will be registered as '
                    "non-deterministic as per SQLite defaults.".format(name),
                    stacklevel=2,
                )

            await self._execute(self._conn.create_function, name, num_params, func)

    # The properties below access the sqlite3 connection directly from the
    # calling thread; they are not routed through the worker queue.

    @property
    def in_transaction(self) -> bool:
        """Value of ``in_transaction`` on the underlying connection."""
        return self._conn.in_transaction

    @property
    def isolation_level(self) -> str:
        """Value of ``isolation_level`` on the underlying connection."""
        return self._conn.isolation_level

    @isolation_level.setter
    def isolation_level(self, value: str) -> None:
        self._conn.isolation_level = value

    @property
    def row_factory(self) -> "Optional[Type]":  # py3.5.2 compat (#24)
        """Value of ``row_factory`` on the underlying connection."""
        return self._conn.row_factory

    @row_factory.setter
    def row_factory(self, factory: "Optional[Type]") -> None:  # py3.5.2 compat (#24)
        self._conn.row_factory = factory

    @property
    def text_factory(self) -> Type:
        """Value of ``text_factory`` on the underlying connection."""
        return self._conn.text_factory

    @text_factory.setter
    def text_factory(self, factory: Type) -> None:
        self._conn.text_factory = factory

    @property
    def total_changes(self) -> int:
        """Value of ``total_changes`` on the underlying connection."""
        return self._conn.total_changes

    async def enable_load_extension(self, value: bool) -> None:
        """Run ``enable_load_extension(value)`` on the worker thread."""
        await self._execute(self._conn.enable_load_extension, value)  # type: ignore

    async def load_extension(self, path: str):
        """Run ``load_extension(path)`` on the worker thread."""
        await self._execute(self._conn.load_extension, path)  # type: ignore

    async def set_progress_handler(
        self, handler: Callable[[], Optional[int]], n: int
    ) -> None:
        """Run ``set_progress_handler(handler, n)`` on the worker thread."""
        await self._execute(self._conn.set_progress_handler, handler, n)

    async def set_trace_callback(self, handler: Callable) -> None:
        """Run ``set_trace_callback(handler)`` on the worker thread."""
        await self._execute(self._conn.set_trace_callback, handler)

    async def iterdump(self) -> AsyncIterator[str]:
        """
        Return an async iterator to dump the database in SQL text format.

        Example::

            async for line in db.iterdump():
                ...

        """
        # The worker thread fills this queue; the coroutine below drains it.
        # A None entry marks the end of the dump.
        dump_queue: Queue = Queue()

        def dumper():
            try:
                for line in self._conn.iterdump():
                    dump_queue.put_nowait(line)
                dump_queue.put_nowait(None)

            except Exception:
                LOG.exception("exception while dumping db")
                # Still send the end marker so the consumer stops polling;
                # the error itself surfaces through `await task` below.
                dump_queue.put_nowait(None)
                raise

        fut = self._execute(dumper)
        task = asyncio.ensure_future(fut)

        while True:
            try:
                line: Optional[str] = dump_queue.get_nowait()
                if line is None:
                    break
                yield line

            except Empty:
                if task.done():
                    LOG.warning("iterdump completed unexpectedly")
                    break

                # Nothing available yet: yield control briefly, then poll
                # again.
                await asyncio.sleep(0.01)

        # Re-raises anything `dumper` (or `_execute`) raised.
        await task

    async def backup(
        self,
        target: Union["Connection", sqlite3.Connection],
        *,
        pages: int = 0,
        progress: Optional[Callable[[int, int, int], None]] = None,
        name: str = "main",
        sleep: float = 0.250
    ) -> None:
        """
        Make a backup of the current database to the target database.

        Takes either a standard sqlite3 or aiosqlite Connection object as the target.

        The keyword arguments are forwarded unchanged to
        ``sqlite3.Connection.backup``.

        :raises RuntimeError: on Python versions older than 3.7.
        """
        if sys.version_info < (3, 7):
            raise RuntimeError("backup() method is only available on Python 3.7+")

        if isinstance(target, Connection):
            # Unwrap the proxy to the sqlite3 connection it holds.
            target = target._conn

        await self._execute(
            self._conn.backup,
            target,
            pages=pages,
            progress=progress,
            name=name,
            sleep=sleep,
        )
372
373
def connect(
    database: Union[str, bytes, Path],
    *,
    iter_chunk_size: int = 64,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    **kwargs: Any
) -> Connection:
    """
    Create and return a connection proxy to the sqlite database.

    The database is not opened here: the returned Connection receives a
    connector callable that performs the actual ``sqlite3.connect`` call.

    :param database: location of the database; ``bytes`` are decoded as
        UTF-8 and any other non-``str`` object is converted with ``str()``.
    :param iter_chunk_size: passed through to the Connection.
    :param loop: ignored; passing a value only emits a DeprecationWarning.
    :param kwargs: forwarded unchanged to ``sqlite3.connect``.
    """

    if loop is not None:
        # stacklevel=2 attributes the warning to the caller of connect().
        warn(
            "aiosqlite.connect() no longer uses the `loop` parameter",
            DeprecationWarning,
            stacklevel=2,
        )

    def connector() -> sqlite3.Connection:
        # Normalise the location to a str before handing it to sqlite3.
        if isinstance(database, str):
            loc = database
        elif isinstance(database, bytes):
            loc = database.decode("utf-8")
        else:
            loc = str(database)

        return sqlite3.connect(loc, **kwargs)

    return Connection(connector, iter_chunk_size)
400