import logging
import typing
import uuid

import aiosqlite
from sqlalchemy.dialects.sqlite import pysqlite
from sqlalchemy.engine.cursor import CursorResultMetaData
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.engine.row import Row
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement

from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend

logger = logging.getLogger("databases")


def _build_row(metadata: CursorResultMetaData, raw_row: typing.Any) -> Row:
    """Wrap one raw driver row in a SQLAlchemy ``Row`` described by *metadata*."""
    return Row(
        metadata,
        metadata._processors,
        metadata._keymap,
        Row._default_key_style,
        raw_row,
    )


class SQLiteBackend(DatabaseBackend):
    """Database backend that talks to SQLite through ``aiosqlite``.

    Holds the parsed URL, the driver options, a SQLAlchemy dialect used to
    compile queries, and a ``SQLitePool`` that hands out connections.
    """

    def __init__(
        self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
    ) -> None:
        """Store configuration and prepare the dialect and pool.

        Args:
            database_url: URL (or ``DatabaseURL``) identifying the database.
            **options: Extra keyword arguments forwarded to
                ``aiosqlite.connect`` whenever a connection is opened.
        """
        self._database_url = DatabaseURL(database_url)
        self._options = options
        self._dialect = pysqlite.dialect(paramstyle="qmark")
        # aiosqlite does not support decimals
        self._dialect.supports_native_decimal = False
        self._pool = SQLitePool(self._database_url, **self._options)

    async def connect(self) -> None:
        """Do nothing: the pool is built in ``__init__`` and needs no startup."""

    async def disconnect(self) -> None:
        """Do nothing: connections are closed individually on release."""

    def connection(self) -> "SQLiteConnection":
        """Return a new, not yet acquired, connection wrapper."""
        return SQLiteConnection(self._pool, self._dialect)


class SQLitePool:
    """Minimal pool-shaped object: every acquire opens a fresh connection."""

    def __init__(self, url: DatabaseURL, **options: typing.Any) -> None:
        """Remember the database URL and the driver options."""
        self._url = url
        self._options = options

    async def acquire(self) -> aiosqlite.Connection:
        """Open and return a new ``aiosqlite`` connection.

        ``isolation_level=None`` is passed to the driver; transaction
        boundaries are issued explicitly by ``SQLiteTransaction``.
        """
        connection = aiosqlite.connect(
            database=self._url.database, isolation_level=None, **self._options
        )
        # Entering the context manager is what actually opens the connection.
        await connection.__aenter__()
        return connection

    async def release(self, connection: aiosqlite.Connection) -> None:
        """Close *connection* by leaving the context entered in ``acquire``."""
        await connection.__aexit__(None, None, None)


class CompilationContext:
    """Thin holder passing an execution context to ``CursorResultMetaData``."""

    def __init__(self, context: ExecutionContext):
        self.context = context


class SQLiteConnection(ConnectionBackend):
    """A single SQLite connection obtained from a ``SQLitePool``.

    Every query method raises ``AssertionError`` if called before
    ``acquire`` or after ``release``.
    """

    def __init__(self, pool: SQLitePool, dialect: Dialect):
        self._pool = pool
        self._dialect = dialect
        self._connection = None  # type: typing.Optional[aiosqlite.Connection]

    async def acquire(self) -> None:
        """Obtain the underlying connection from the pool."""
        assert self._connection is None, "Connection is already acquired"
        self._connection = await self._pool.acquire()

    async def release(self) -> None:
        """Return the underlying connection to the pool and forget it."""
        assert self._connection is not None, "Connection is not acquired"
        await self._pool.release(self._connection)
        self._connection = None

    async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:
        """Run *query* and return every result row as a SQLAlchemy ``Row``."""
        assert self._connection is not None, "Connection is not acquired"
        query_str, args, context = self._compile(query)

        async with self._connection.execute(query_str, args) as cursor:
            rows = await cursor.fetchall()
            metadata = CursorResultMetaData(context, cursor.description)
            return [_build_row(metadata, row) for row in rows]

    async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mapping]:
        """Run *query* and return the first row, or ``None`` if there is none."""
        assert self._connection is not None, "Connection is not acquired"
        query_str, args, context = self._compile(query)

        async with self._connection.execute(query_str, args) as cursor:
            row = await cursor.fetchone()
            if row is None:
                return None
            metadata = CursorResultMetaData(context, cursor.description)
            return _build_row(metadata, row)

    async def execute(self, query: ClauseElement) -> typing.Any:
        """Run *query* and return ``lastrowid``, or ``rowcount`` when it is 0."""
        assert self._connection is not None, "Connection is not acquired"
        query_str, args, _ = self._compile(query)
        async with self._connection.cursor() as cursor:
            await cursor.execute(query_str, args)
            if cursor.lastrowid == 0:
                return cursor.rowcount
            return cursor.lastrowid

    async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
        """Run each query in *queries* in order, discarding the results."""
        assert self._connection is not None, "Connection is not acquired"
        for single_query in queries:
            await self.execute(single_query)

    async def iterate(
        self, query: ClauseElement
    ) -> typing.AsyncGenerator[typing.Any, None]:
        """Run *query* and yield result rows one at a time as ``Row`` objects."""
        assert self._connection is not None, "Connection is not acquired"
        query_str, args, context = self._compile(query)
        async with self._connection.execute(query_str, args) as cursor:
            metadata = CursorResultMetaData(context, cursor.description)
            async for row in cursor:
                yield _build_row(metadata, row)

    def transaction(self) -> TransactionBackend:
        """Return a transaction object bound to this connection."""
        return SQLiteTransaction(self)

    def _compile(
        self, query: ClauseElement
    ) -> typing.Tuple[str, list, CompilationContext]:
        """Compile *query* for the SQLite dialect.

        Returns:
            A tuple of the SQL string, the positional argument list, and a
            ``CompilationContext`` used to build result metadata.
        """
        compiled = query.compile(
            dialect=self._dialect, compile_kwargs={"render_postcompile": True}
        )

        execution_context = self._dialect.execution_ctx_cls()
        execution_context.dialect = self._dialect

        args = []

        if not isinstance(query, DDLElement):
            params = compiled.construct_params()
            processors = compiled._bind_processors
            # The dialect uses the positional "qmark" paramstyle, so values
            # must follow the order (and repetition) of the placeholders in
            # the SQL text. ``positiontup`` records exactly that; the params
            # dict alone has one entry per name in no guaranteed order.
            ordered_keys = getattr(compiled, "positiontup", None)
            if ordered_keys is None:
                ordered_keys = list(params)
            for key in ordered_keys:
                raw_val = params[key]
                if key in processors:
                    args.append(processors[key](raw_val))
                else:
                    args.append(raw_val)

            execution_context.result_column_struct = (
                compiled._result_columns,
                compiled._ordered_columns,
                compiled._textual_ordered_columns,
                compiled._loose_column_name_matching,
            )

        # Flatten the statement onto one line for the debug log only.
        query_message = compiled.string.replace(" \n", " ").replace("\n", " ")
        logger.debug(
            "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA
        )
        return compiled.string, args, CompilationContext(execution_context)

    @property
    def raw_connection(self) -> aiosqlite.core.Connection:
        """The underlying ``aiosqlite`` connection; asserts it is acquired."""
        assert self._connection is not None, "Connection is not acquired"
        return self._connection


class SQLiteTransaction(TransactionBackend):
    """Transaction on a ``SQLiteConnection``.

    A root transaction issues ``BEGIN`` / ``COMMIT`` / ``ROLLBACK``; a nested
    one uses a uniquely named savepoint instead.
    """

    def __init__(self, connection: SQLiteConnection):
        self._connection = connection
        self._is_root = False
        self._savepoint_name = ""

    async def _run(self, statement: str) -> None:
        """Execute a bare SQL *statement* and close its cursor.

        Raises ``AssertionError`` (via ``raw_connection``) when the
        connection has not been acquired.
        """
        # Leaving the ``async with`` closes the cursor; no explicit close.
        async with self._connection.raw_connection.execute(statement):
            pass

    async def start(
        self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]
    ) -> None:
        """Begin the transaction, or create a savepoint when not the root.

        Args:
            is_root: ``True`` for an outermost transaction.
            extra_options: Accepted for interface compatibility; unused here.
        """
        assert self._connection._connection is not None, "Connection is not acquired"
        self._is_root = is_root
        if self._is_root:
            await self._run("BEGIN")
        else:
            suffix = str(uuid.uuid4()).replace("-", "_")
            self._savepoint_name = f"STARLETTE_SAVEPOINT_{suffix}"
            await self._run(f"SAVEPOINT {self._savepoint_name}")

    async def commit(self) -> None:
        """Commit the root transaction, or release this savepoint."""
        if self._is_root:
            await self._run("COMMIT")
        else:
            await self._run(f"RELEASE SAVEPOINT {self._savepoint_name}")

    async def rollback(self) -> None:
        """Roll back the root transaction, or roll back to this savepoint."""
        if self._is_root:
            await self._run("ROLLBACK")
        else:
            await self._run(f"ROLLBACK TO SAVEPOINT {self._savepoint_name}")