1import logging
2import typing
3import uuid
4
5import aiosqlite
6from sqlalchemy.dialects.sqlite import pysqlite
7from sqlalchemy.engine.cursor import CursorResultMetaData
8from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
9from sqlalchemy.engine.row import Row
10from sqlalchemy.sql import ClauseElement
11from sqlalchemy.sql.ddl import DDLElement
12
13from databases.core import LOG_EXTRA, DatabaseURL
14from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend
15
16logger = logging.getLogger("databases")
17
18
class SQLiteBackend(DatabaseBackend):
    """``databases`` backend for SQLite, implemented on top of aiosqlite.

    There is no long-lived pool to manage: :class:`SQLitePool` opens a fresh
    aiosqlite connection every time a connection is acquired and closes it on
    release, so :meth:`connect` and :meth:`disconnect` have nothing to do.
    """

    def __init__(
        self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
    ) -> None:
        """Configure the backend.

        :param database_url: URL (or URL string) whose ``database`` component
            names the SQLite database to open.
        :param options: extra keyword arguments forwarded to
            ``aiosqlite.connect`` for every connection the pool opens.
        """
        self._database_url = DatabaseURL(database_url)
        self._options = options
        # qmark (``?``) is the placeholder style the sqlite3 driver expects.
        self._dialect = pysqlite.dialect(paramstyle="qmark")
        # aiosqlite does not support decimals
        self._dialect.supports_native_decimal = False
        self._pool = SQLitePool(self._database_url, **self._options)

    async def connect(self) -> None:
        """No-op: connections are opened lazily, one per ``acquire``."""

    async def disconnect(self) -> None:
        """No-op: each connection is closed when it is released."""

    def connection(self) -> "SQLiteConnection":
        """Return a new, not-yet-acquired connection bound to this backend."""
        return SQLiteConnection(self._pool, self._dialect)
51
52
class SQLitePool:
    """Pool-shaped factory that opens a new aiosqlite connection per acquire.

    Nothing is cached or reused: ``acquire`` always creates a connection and
    ``release`` always closes the one it is given.
    """

    def __init__(self, url: DatabaseURL, **options: typing.Any) -> None:
        self._options = options
        self._url = url

    async def acquire(self) -> aiosqlite.Connection:
        """Open a connection to the configured database and return it.

        ``isolation_level=None`` is always passed so the driver does not open
        transactions implicitly; transaction control is issued explicitly
        elsewhere (``BEGIN`` / ``SAVEPOINT`` statements).
        """
        conn = aiosqlite.connect(
            database=self._url.database,
            isolation_level=None,
            **self._options,
        )
        # Drive the async-context-manager protocol by hand so the connection
        # outlives this call; ``release`` performs the matching exit.
        await conn.__aenter__()
        return conn

    async def release(self, connection: aiosqlite.Connection) -> None:
        """Close a connection previously returned by :meth:`acquire`."""
        await connection.__aexit__(None, None, None)
67
68
class CompilationContext:
    """Thin wrapper exposing an execution context under ``.context``.

    Instances are what ``SQLiteConnection._compile`` hands to
    ``CursorResultMetaData`` as its first argument; presumably SQLAlchemy only
    reads ``.context`` from it (TODO confirm against the SQLAlchemy version in
    use).
    """

    def __init__(self, context: ExecutionContext) -> None:
        # Stored as-is; no copying or validation.
        self.context = context
72
73
class SQLiteConnection(ConnectionBackend):
    """One aiosqlite connection borrowed from a :class:`SQLitePool`.

    Queries are SQLAlchemy ``ClauseElement`` objects compiled with the
    backend's dialect and executed on the raw aiosqlite connection; result
    rows are wrapped in SQLAlchemy ``Row`` objects.
    """

    def __init__(self, pool: SQLitePool, dialect: Dialect):
        self._pool = pool
        self._dialect = dialect
        self._connection = None  # type: typing.Optional[aiosqlite.Connection]

    async def acquire(self) -> None:
        """Obtain a connection from the pool; must not already hold one."""
        assert self._connection is None, "Connection is already acquired"
        self._connection = await self._pool.acquire()

    async def release(self) -> None:
        """Hand the connection back to the pool and drop the reference."""
        assert self._connection is not None, "Connection is not acquired"
        await self._pool.release(self._connection)
        self._connection = None

    async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:
        """Execute *query* and return all result rows as ``Row`` objects."""
        assert self._connection is not None, "Connection is not acquired"
        query_str, args, context = self._compile(query)

        async with self._connection.execute(query_str, args) as cursor:
            rows = await cursor.fetchall()
            metadata = CursorResultMetaData(context, cursor.description)
            return [self._make_row(metadata, row) for row in rows]

    async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mapping]:
        """Execute *query* and return its first row, or ``None`` if it is empty."""
        assert self._connection is not None, "Connection is not acquired"
        query_str, args, context = self._compile(query)

        async with self._connection.execute(query_str, args) as cursor:
            row = await cursor.fetchone()
            if row is None:
                return None
            metadata = CursorResultMetaData(context, cursor.description)
            return self._make_row(metadata, row)

    async def execute(self, query: ClauseElement) -> typing.Any:
        """Execute *query* without fetching rows.

        Returns ``cursor.lastrowid``, except when it is 0, in which case
        ``cursor.rowcount`` is returned instead.
        """
        assert self._connection is not None, "Connection is not acquired"
        query_str, args, context = self._compile(query)
        async with self._connection.cursor() as cursor:
            await cursor.execute(query_str, args)
            if cursor.lastrowid == 0:
                return cursor.rowcount
            return cursor.lastrowid

    async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
        """Execute each query in order via :meth:`execute`, discarding results."""
        assert self._connection is not None, "Connection is not acquired"
        for single_query in queries:
            await self.execute(single_query)

    async def iterate(
        self, query: ClauseElement
    ) -> typing.AsyncGenerator[typing.Any, None]:
        """Execute *query* and yield one ``Row`` per cursor row as it is read."""
        assert self._connection is not None, "Connection is not acquired"
        query_str, args, context = self._compile(query)
        async with self._connection.execute(query_str, args) as cursor:
            metadata = CursorResultMetaData(context, cursor.description)
            async for row in cursor:
                yield self._make_row(metadata, row)

    def transaction(self) -> TransactionBackend:
        """Return a new transaction object bound to this connection."""
        return SQLiteTransaction(self)

    @staticmethod
    def _make_row(metadata: CursorResultMetaData, row: typing.Sequence) -> Row:
        """Wrap a raw driver row tuple in a SQLAlchemy ``Row`` using *metadata*."""
        return Row(
            metadata,
            metadata._processors,
            metadata._keymap,
            Row._default_key_style,
            row,
        )

    def _compile(
        self, query: ClauseElement
    ) -> typing.Tuple[str, list, CompilationContext]:
        """Compile *query* into SQL text plus its positional bind values.

        Returns ``(sql, args, context)``:

        - *args* lists the bound values in the order of the ``?`` placeholders
          in *sql*, after any per-parameter bind processors have run.
        - *context* wraps an execution context carrying the result-column
          structure, as needed by ``CursorResultMetaData``.

        DDL statements get no args and no result-column structure.
        """
        compiled = query.compile(
            dialect=self._dialect, compile_kwargs={"render_postcompile": True}
        )

        execution_context = self._dialect.execution_ctx_cls()
        execution_context.dialect = self._dialect

        args = []

        if not isinstance(query, DDLElement):
            params = compiled.construct_params()
            # qmark placeholders are positional, so values must follow the
            # placeholder order recorded in ``positiontup``. The params dict's
            # own ordering need not match the SQL text, and it holds a name only
            # once even when the statement references that bind several times.
            for key in compiled.positiontup:
                raw_val = params[key]
                if key in compiled._bind_processors:
                    val = compiled._bind_processors[key](raw_val)
                else:
                    val = raw_val
                args.append(val)

            execution_context.result_column_struct = (
                compiled._result_columns,
                compiled._ordered_columns,
                compiled._textual_ordered_columns,
                compiled._loose_column_name_matching,
            )

        # Collapse newlines so the logged query fits on one line.
        query_message = compiled.string.replace(" \n", " ").replace("\n", " ")
        logger.debug(
            "Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA
        )
        return compiled.string, args, CompilationContext(execution_context)

    @property
    def raw_connection(self) -> aiosqlite.core.Connection:
        """The underlying aiosqlite connection; only valid while acquired."""
        assert self._connection is not None, "Connection is not acquired"
        return self._connection
194
195
class SQLiteTransaction(TransactionBackend):
    """A transaction, or a nested savepoint, on a :class:`SQLiteConnection`.

    The outermost (root) transaction uses ``BEGIN`` / ``COMMIT`` /
    ``ROLLBACK``. Nested transactions are emulated with uniquely named SQLite
    savepoints.
    """

    def __init__(self, connection: SQLiteConnection):
        self._connection = connection
        self._is_root = False
        self._savepoint_name = ""

    async def _execute_statement(self, statement: str) -> None:
        """Run one transaction-control *statement* on the raw connection."""
        raw = self._connection._connection
        assert raw is not None, "Connection is not acquired"
        # No result is read; the cursor context manager closes the cursor.
        async with raw.execute(statement):
            pass

    async def start(
        self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]
    ) -> None:
        """Begin a root transaction, or create a fresh savepoint when nested.

        *extra_options* is accepted for interface compatibility but unused.
        """
        assert self._connection._connection is not None, "Connection is not acquired"
        self._is_root = is_root
        if self._is_root:
            await self._execute_statement("BEGIN")
        else:
            # Hyphens are not valid in an unquoted SQL identifier.
            savepoint_id = str(uuid.uuid4()).replace("-", "_")
            self._savepoint_name = f"STARLETTE_SAVEPOINT_{savepoint_id}"
            await self._execute_statement(f"SAVEPOINT {self._savepoint_name}")

    async def commit(self) -> None:
        """Commit the root transaction, or release this savepoint when nested."""
        assert self._connection._connection is not None, "Connection is not acquired"
        if self._is_root:
            await self._execute_statement("COMMIT")
        else:
            await self._execute_statement(
                f"RELEASE SAVEPOINT {self._savepoint_name}"
            )

    async def rollback(self) -> None:
        """Roll back the root transaction, or roll back to this savepoint."""
        assert self._connection._connection is not None, "Connection is not acquired"
        if self._is_root:
            await self._execute_statement("ROLLBACK")
        else:
            # NOTE(review): per the SQLite docs, ROLLBACK TO leaves the
            # savepoint on the transaction stack; it is not released here.
            await self._execute_statement(
                f"ROLLBACK TO SAVEPOINT {self._savepoint_name}"
            )
239