# Copyright 2018 John Reese
# Licensed under the MIT license

"""
Core implementation of aiosqlite proxies
"""

import asyncio
import logging
import sqlite3
import sys
import warnings
from functools import partial
from pathlib import Path
from queue import Empty, Queue
from threading import Thread
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Generator,
    Iterable,
    Optional,
    Type,
    Union,
)
from warnings import warn

from .context import contextmanager
from .cursor import Cursor

__all__ = ["connect", "Connection", "Cursor"]

LOG = logging.getLogger("aiosqlite")


def get_loop(future: asyncio.Future) -> asyncio.AbstractEventLoop:
    """Return the event loop that *future* belongs to.

    ``Future.get_loop()`` only exists on Python 3.7+; older interpreters
    fall back to the private ``_loop`` attribute.
    """
    if sys.version_info >= (3, 7):
        return future.get_loop()
    else:
        return future._loop


class Connection(Thread):
    """Asyncio proxy for a :class:`sqlite3.Connection`.

    All sqlite3 calls are queued onto ``self._tx`` and executed by this
    thread's :meth:`run` loop; each queued item carries an asyncio future
    that is resolved back on the event loop that created it.
    """

    def __init__(
        self,
        connector: Callable[[], sqlite3.Connection],
        iter_chunk_size: int,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """
        :param connector: zero-argument callable, run on the worker thread,
            that opens and returns the underlying sqlite3 connection.
        :param iter_chunk_size: stored as ``_iter_chunk_size``; not used
            within this module (presumably read by ``Cursor`` -- verify).
        :param loop: deprecated and ignored; passing it emits a
            ``DeprecationWarning``.
        """
        super().__init__()
        self._running = True
        self._connection: Optional[sqlite3.Connection] = None
        self._connector = connector
        self._tx: Queue = Queue()
        self._iter_chunk_size = iter_chunk_size

        if loop is not None:
            warn(
                "aiosqlite.Connection no longer uses the `loop` parameter",
                DeprecationWarning,
            )

    @property
    def _conn(self) -> sqlite3.Connection:
        """The live sqlite3 connection.

        :raises ValueError: if not connected yet, or already closed.
        """
        if self._connection is None:
            raise ValueError("no active connection")

        return self._connection

    def _execute_insert(
        self, sql: str, parameters: Iterable[Any]
    ) -> Optional[sqlite3.Row]:
        """Run *sql*, then return the row from ``SELECT last_insert_rowid()``.

        Runs on the worker thread (queued via :meth:`_execute`).
        """
        cursor = self._conn.execute(sql, parameters)
        cursor.execute("SELECT last_insert_rowid()")
        return cursor.fetchone()

    def _execute_fetchall(
        self, sql: str, parameters: Iterable[Any]
    ) -> Iterable[sqlite3.Row]:
        """Run *sql* and return every resulting row.

        Runs on the worker thread (queued via :meth:`_execute`).
        """
        cursor = self._conn.execute(sql, parameters)
        return cursor.fetchall()

    def run(self) -> None:
        """
        Execute function calls on a separate thread.

        :meta private:
        """

        # These run on the event loop thread (scheduled below). The awaiting
        # coroutine may have been cancelled meanwhile, so never resolve a
        # future that is already done.
        def set_result(fut: asyncio.Future, result: Any) -> None:
            if not fut.done():
                fut.set_result(result)

        def set_exception(fut: asyncio.Future, e: BaseException) -> None:
            if not fut.done():
                fut.set_exception(e)

        while True:
            # Continues running until all queue items are processed,
            # even after connection is closed (so we can finalize all
            # futures)
            try:
                future, function = self._tx.get(timeout=0.1)
            except Empty:
                if self._running:
                    continue
                break

            # Only the user function is guarded here; delivering the outcome
            # is handled separately so a delivery failure is not reported as
            # a failure of the operation itself.
            try:
                LOG.debug("executing %s", function)
                result = function()
                LOG.debug("operation %s completed", function)
            except BaseException as e:
                LOG.debug("returning exception %s", e)
                callback, value = set_exception, e
            else:
                callback, value = set_result, result

            try:
                get_loop(future).call_soon_threadsafe(callback, future, value)
            except RuntimeError:
                # The future's event loop is already closed, so nobody can
                # observe this outcome. Keep serving the queue instead of
                # letting the worker thread die and strand later futures.
                LOG.debug("event loop closed; dropping outcome of %s", function)

    async def _execute(self, fn, *args, **kwargs):
        """Queue a function with the given arguments for execution.

        :raises ValueError: if the connection is closed or not yet open.
        :returns: whatever ``fn(*args, **kwargs)`` returned on the worker
            thread; exceptions raised there are re-raised here.
        """
        if not self._running or not self._connection:
            raise ValueError("Connection closed")

        function = partial(fn, *args, **kwargs)
        future = asyncio.get_event_loop().create_future()

        self._tx.put_nowait((future, function))

        return await future

    async def _connect(self) -> "Connection":
        """Connect to the actual sqlite database."""
        if self._connection is None:
            try:
                future = asyncio.get_event_loop().create_future()
                self._tx.put_nowait((future, self._connector))
                self._connection = await future
            except BaseException:
                # BaseException (not Exception) so that cancellation -- a
                # BaseException since Python 3.8 -- also stops the worker
                # thread instead of leaving it polling the queue forever.
                self._running = False
                self._connection = None
                raise

        return self

    def __await__(self) -> Generator[Any, None, "Connection"]:
        """Start the worker thread and open the database.

        Note: starts the thread, so a connection can only be awaited once.
        """
        self.start()
        return self._connect().__await__()

    async def __aenter__(self) -> "Connection":
        return await self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.close()

    @contextmanager
    async def cursor(self) -> Cursor:
        """Create an aiosqlite cursor wrapping a sqlite3 cursor object."""
        return Cursor(self, await self._execute(self._conn.cursor))

    async def commit(self) -> None:
        """Commit the current transaction."""
        await self._execute(self._conn.commit)

    async def rollback(self) -> None:
        """Roll back the current transaction."""
        await self._execute(self._conn.rollback)

    async def close(self) -> None:
        """Complete queued queries/cursors and close the connection.

        Closing an already-closed (or never-opened) connection is a no-op,
        so an explicit ``close()`` inside ``async with`` is safe.
        """
        if self._connection is None:
            return

        try:
            await self._execute(self._conn.close)
        except Exception:
            LOG.info("exception occurred while closing connection")
            raise
        finally:
            # Stopping the run loop lets the worker drain the queue and exit.
            self._running = False
            self._connection = None

    @contextmanager
    async def execute(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Cursor:
        """Helper to create a cursor and execute the given query."""
        if parameters is None:
            parameters = []
        cursor = await self._execute(self._conn.execute, sql, parameters)
        return Cursor(self, cursor)

    @contextmanager
    async def execute_insert(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Optional[sqlite3.Row]:
        """Helper to insert and get the last_insert_rowid."""
        if parameters is None:
            parameters = []
        return await self._execute(self._execute_insert, sql, parameters)

    @contextmanager
    async def execute_fetchall(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> Iterable[sqlite3.Row]:
        """Helper to execute a query and return all the data."""
        if parameters is None:
            parameters = []
        return await self._execute(self._execute_fetchall, sql, parameters)

    @contextmanager
    async def executemany(
        self, sql: str, parameters: Iterable[Iterable[Any]]
    ) -> Cursor:
        """Helper to create a cursor and execute the given multiquery."""
        cursor = await self._execute(self._conn.executemany, sql, parameters)
        return Cursor(self, cursor)

    @contextmanager
    async def executescript(self, sql_script: str) -> Cursor:
        """Helper to create a cursor and execute a user script."""
        cursor = await self._execute(self._conn.executescript, sql_script)
        return Cursor(self, cursor)

    async def interrupt(self) -> None:
        """Interrupt pending queries.

        Called directly on the caller's thread rather than through the
        worker queue, which may be busy running the query to interrupt.
        """
        return self._conn.interrupt()

    async def create_function(
        self, name: str, num_params: int, func: Callable, deterministic: bool = False
    ) -> None:
        """
        Create user-defined function that can be later used
        within SQL statements. Must be run within the same thread
        that query executions take place so instead of executing directly
        against the connection, we defer this to `run` function.

        In Python 3.8 and above, if *deterministic* is true, the created
        function is marked as deterministic, which allows SQLite to perform
        additional optimizations. This flag is supported by SQLite 3.8.3 or
        higher, ``NotSupportedError`` will be raised if used with older
        versions.
        """
        if sys.version_info >= (3, 8):
            await self._execute(
                self._conn.create_function,
                name,
                num_params,
                func,
                deterministic=deterministic,
            )
        else:
            if deterministic:
                warnings.warn(
                    "Deterministic function support is only available on "
                    'Python 3.8+. Function "{}" will be registered as '
                    "non-deterministic as per SQLite defaults.".format(name)
                )

            await self._execute(self._conn.create_function, name, num_params, func)

    @property
    def in_transaction(self) -> bool:
        return self._conn.in_transaction

    @property
    def isolation_level(self) -> Optional[str]:
        # ``None`` means autocommit mode in the sqlite3 module.
        return self._conn.isolation_level

    @isolation_level.setter
    def isolation_level(self, value: Optional[str]) -> None:
        self._conn.isolation_level = value

    @property
    def row_factory(self) -> "Optional[Type]":  # py3.5.2 compat (#24)
        return self._conn.row_factory

    @row_factory.setter
    def row_factory(self, factory: "Optional[Type]") -> None:  # py3.5.2 compat (#24)
        self._conn.row_factory = factory

    @property
    def text_factory(self) -> Type:
        return self._conn.text_factory

    @text_factory.setter
    def text_factory(self, factory: Type) -> None:
        self._conn.text_factory = factory

    @property
    def total_changes(self) -> int:
        return self._conn.total_changes

    async def enable_load_extension(self, value: bool) -> None:
        await self._execute(self._conn.enable_load_extension, value)  # type: ignore

    async def load_extension(self, path: str) -> None:
        await self._execute(self._conn.load_extension, path)  # type: ignore

    async def set_progress_handler(
        self, handler: Callable[[], Optional[int]], n: int
    ) -> None:
        await self._execute(self._conn.set_progress_handler, handler, n)

    async def set_trace_callback(self, handler: Callable) -> None:
        await self._execute(self._conn.set_trace_callback, handler)

    async def iterdump(self) -> AsyncIterator[str]:
        """
        Return an async iterator to dump the database in SQL text format.

        Example::

            async for line in db.iterdump():
                ...

        """
        # Lines are handed from the worker thread to this coroutine through a
        # thread-safe queue; ``None`` marks the end of the dump.
        dump_queue: Queue = Queue()

        def dumper():
            try:
                for line in self._conn.iterdump():
                    dump_queue.put_nowait(line)
                dump_queue.put_nowait(None)

            except Exception:
                LOG.exception("exception while dumping db")
                dump_queue.put_nowait(None)
                raise

        fut = self._execute(dumper)
        task = asyncio.ensure_future(fut)

        while True:
            try:
                line: Optional[str] = dump_queue.get_nowait()
                if line is None:
                    break
                yield line

            except Empty:
                if task.done():
                    LOG.warning("iterdump completed unexpectedly")
                    break

                await asyncio.sleep(0.01)

        # Re-raises any exception from dumper() (or from _execute itself).
        await task

    async def backup(
        self,
        target: Union["Connection", sqlite3.Connection],
        *,
        pages: int = 0,
        progress: Optional[Callable[[int, int, int], None]] = None,
        name: str = "main",
        sleep: float = 0.250
    ) -> None:
        """
        Make a backup of the current database to the target database.

        Takes either a standard sqlite3 or aiosqlite Connection object as the target.
        The keyword arguments are forwarded unchanged to
        :meth:`sqlite3.Connection.backup`.

        :raises RuntimeError: on Python versions older than 3.7.
        """
        if sys.version_info < (3, 7):
            raise RuntimeError("backup() method is only available on Python 3.7+")

        if isinstance(target, Connection):
            target = target._conn

        await self._execute(
            self._conn.backup,
            target,
            pages=pages,
            progress=progress,
            name=name,
            sleep=sleep,
        )


def connect(
    database: Union[str, bytes, Path],
    *,
    iter_chunk_size=64,
    loop: Optional[asyncio.AbstractEventLoop] = None,
    **kwargs: Any
) -> Connection:
    """Create and return a connection proxy to the sqlite database.

    The database is not opened until the returned object is awaited (or
    entered with ``async with``). Extra keyword arguments are forwarded to
    :func:`sqlite3.connect`; *loop* is deprecated and ignored.
    """

    if loop is not None:
        warn(
            "aiosqlite.connect() no longer uses the `loop` parameter",
            DeprecationWarning,
        )

    def connector() -> sqlite3.Connection:
        # Runs on the connection's worker thread, so the sqlite3 connection
        # is created on the same thread that later uses it.
        if isinstance(database, str):
            loc = database
        elif isinstance(database, bytes):
            loc = database.decode("utf-8")
        else:
            loc = str(database)

        return sqlite3.connect(loc, **kwargs)

    return Connection(connector, iter_chunk_size)