1# Copyright 2014-2016 OpenMarket Ltd
2# Copyright 2017-2018 New Vector Ltd
3# Copyright 2019 The Matrix.org Foundation C.I.C.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16import inspect
17import logging
18import time
19import types
20from collections import defaultdict
21from sys import intern
22from time import monotonic as monotonic_time
23from typing import (
24    TYPE_CHECKING,
25    Any,
26    Callable,
27    Collection,
28    Dict,
29    Iterable,
30    Iterator,
31    List,
32    Optional,
33    Tuple,
34    TypeVar,
35    cast,
36    overload,
37)
38
39import attr
40from prometheus_client import Histogram
41from typing_extensions import Literal
42
43from twisted.enterprise import adbapi
44
45from synapse.api.errors import StoreError
46from synapse.config.database import DatabaseConnectionConfig
47from synapse.logging import opentracing
48from synapse.logging.context import (
49    LoggingContext,
50    current_context,
51    make_deferred_yieldable,
52)
53from synapse.metrics import register_threadpool
54from synapse.metrics.background_process_metrics import run_as_background_process
55from synapse.storage.background_updates import BackgroundUpdater
56from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
57from synapse.storage.types import Connection, Cursor
58from synapse.util.iterutils import batch_iter
59
60if TYPE_CHECKING:
61    from synapse.server import HomeServer
62
# python 3 does not have a maximum int value
# (this is the largest value representable in a signed 64-bit database column,
# which is what we modulo transaction IDs by in `DatabasePool.new_transaction`)
MAX_TXN_ID = 2 ** 63 - 1

logger = logging.getLogger(__name__)

# Dedicated loggers so operators can tune verbosity of raw SQL, transaction
# lifecycle, and periodic performance summaries independently.
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
perf_logger = logging.getLogger("synapse.storage.TIME")

# Time spent waiting for a connection from the pool before a query could run.
sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")

# Per-statement execution time, labelled by SQL verb (SELECT, INSERT, ...).
sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
# Whole-transaction time, labelled by the transaction description.
sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])


# Unique indexes which have been added in background updates. Maps from table name
# to the name of the background update which added the unique index to that table.
#
# This is used by the upsert logic to figure out which tables are safe to do a proper
# UPSERT on: until the relevant background update has completed, we
# have to emulate an upsert by locking the table.
#
UNIQUE_INDEX_BACKGROUND_UPDATES: Dict[str, str] = {
    "user_ips": "user_ips_device_unique_index",
    "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
    "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
    "event_search": "event_search_event_id_idx",
}
91
92
def make_pool(
    reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> adbapi.ConnectionPool:
    """Create the twisted connection pool for the given database.

    The pool's threadpool is registered for metrics under the database's name.
    """

    # Work on a copy of the configured connection args so we never mutate the
    # config dict, and turn on `cp_reconnect` unless the operator has set it
    # explicitly.
    pool_kwargs = dict(db_config.config.get("args", {}))
    if "cp_reconnect" not in pool_kwargs:
        pool_kwargs["cp_reconnect"] = True

    def _on_new_connection(conn):
        # Run the engine's per-connection setup under a logging context so any
        # queries it issues are tracked correctly.
        with LoggingContext("db.on_new_connection"):
            engine.on_new_connection(
                LoggingDatabaseConnection(conn, engine, "on_new_connection")
            )

    pool = adbapi.ConnectionPool(
        db_config.config["name"],
        cp_reactor=reactor,
        cp_openfun=_on_new_connection,
        **pool_kwargs,
    )

    register_threadpool(f"database-{db_config.name}", pool.threadpool)

    return pool
121
122
def make_conn(
    db_config: DatabaseConnectionConfig,
    engine: BaseDatabaseEngine,
    default_txn_name: str,
) -> "LoggingDatabaseConnection":
    """Open a single new connection to the database and return it wrapped in a
    `LoggingDatabaseConnection`.

    Connection-pool options (the ``cp_*`` args) are meaningless for a lone
    connection, so they are stripped before being handed to the driver.
    """

    connect_kwargs = {}
    for key, value in db_config.config.get("args", {}).items():
        if not key.startswith("cp_"):
            connect_kwargs[key] = value

    db_conn = LoggingDatabaseConnection(
        engine.module.connect(**connect_kwargs), engine, default_txn_name
    )
    engine.on_new_connection(db_conn)
    return db_conn
144
145
@attr.s(slots=True)
class LoggingDatabaseConnection:
    """A wrapper around a database connection whose `cursor` method hands back
    `LoggingTransaction` objects rather than raw cursors.

    This is mainly used on startup to ensure that queries get logged correctly
    """

    conn = attr.ib(type=Connection)
    engine = attr.ib(type=BaseDatabaseEngine)
    default_txn_name = attr.ib(type=str)

    def cursor(
        self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
    ) -> "LoggingTransaction":
        """Open a new cursor, wrapped for logging/metrics.

        Falls back to `default_txn_name` when no name is supplied.
        """
        return LoggingTransaction(
            self.conn.cursor(),
            name=txn_name or self.default_txn_name,
            database_engine=self.engine,
            after_callbacks=after_callbacks,
            exception_callbacks=exception_callbacks,
        )

    def close(self) -> None:
        self.conn.close()

    def commit(self) -> None:
        self.conn.commit()

    def rollback(self) -> None:
        self.conn.rollback()

    def __enter__(self) -> "LoggingDatabaseConnection":
        # Delegate to the wrapped connection's context-manager protocol, but
        # hand back *this* wrapper so callers keep getting logged cursors.
        self.conn.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
        return self.conn.__exit__(exc_type, exc_value, traceback)

    # Proxy through any unknown lookups to the DB conn class.
    def __getattr__(self, name):
        return getattr(self.conn, name)
191
192
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
# Each entry is a (callback, positional args, keyword args) triple which is later
# invoked as ``callback(*args, **kwargs)`` once the transaction completes/fails.
_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]


# Generic return type used by the transaction-running helpers below.
R = TypeVar("R")
198
199
class LoggingTransaction:
    """An object that almost-transparently proxies for the 'txn' object
    passed to the constructor. Adds logging and metrics to the .execute()
    method.

    Args:
        txn: The database transaction object to wrap.
        name: The name of this transactions for logging.
        database_engine
        after_callbacks: A list that callbacks will be appended to
            that have been added by `call_after` which should be run on
            successful completion of the transaction. None indicates that no
            callbacks should be allowed to be scheduled to run.
        exception_callbacks: A list that callbacks will be appended
            to that have been added by `call_on_exception` which should be run
            if transaction ends with an error. None indicates that no callbacks
            should be allowed to be scheduled to run.
    """

    __slots__ = [
        "txn",
        "name",
        "database_engine",
        "after_callbacks",
        "exception_callbacks",
    ]

    def __init__(
        self,
        txn: Cursor,
        name: str,
        database_engine: BaseDatabaseEngine,
        after_callbacks: Optional[List[_CallbackListEntry]] = None,
        exception_callbacks: Optional[List[_CallbackListEntry]] = None,
    ):
        self.txn = txn
        self.name = name
        self.database_engine = database_engine
        self.after_callbacks = after_callbacks
        self.exception_callbacks = exception_callbacks

    def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
        """Call the given callback on the main twisted thread after the
        transaction has finished. Used to invalidate the caches on the
        correct thread.
        """
        # if self.after_callbacks is None, that means that whatever constructed the
        # LoggingTransaction isn't expecting there to be any callbacks; assert that
        # is not the case.
        assert self.after_callbacks is not None
        self.after_callbacks.append((callback, args, kwargs))

    def call_on_exception(
        self, callback: Callable[..., object], *args: Any, **kwargs: Any
    ):
        """Call the given callback on the main twisted thread if the
        transaction ends with an error.
        """
        # if self.exception_callbacks is None, that means that whatever constructed the
        # LoggingTransaction isn't expecting there to be any callbacks; assert that
        # is not the case.
        assert self.exception_callbacks is not None
        self.exception_callbacks.append((callback, args, kwargs))

    def fetchone(self) -> Optional[Tuple]:
        """Fetch the next row of the result set, or None when exhausted."""
        return self.txn.fetchone()

    def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
        """Fetch the next `size` rows of the result set.

        `size=None` means "use the cursor's default arraysize". psycopg2
        accepts `size=None` directly, but sqlite3's C-level `fetchmany`
        rejects a None size, so in that case we must call it with no
        argument at all.
        """
        if size is None:
            return self.txn.fetchmany()
        return self.txn.fetchmany(size)

    def fetchall(self) -> List[Tuple]:
        """Fetch all (remaining) rows of the result set."""
        return self.txn.fetchall()

    def __iter__(self) -> Iterator[Tuple]:
        return self.txn.__iter__()

    @property
    def rowcount(self) -> int:
        # Number of rows affected by the last statement, per DB-API 2.0.
        return self.txn.rowcount

    @property
    def description(self) -> Any:
        # Column descriptions of the last result set, per DB-API 2.0.
        return self.txn.description

    def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
        """Similar to `executemany`, except `txn.rowcount` will not be correct
        afterwards.

        More efficient than `executemany` on PostgreSQL
        """

        if isinstance(self.database_engine, PostgresEngine):
            from psycopg2.extras import execute_batch  # type: ignore

            self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
        else:
            self.executemany(sql, args)

    def execute_values(self, sql: str, *args: Any, fetch: bool = True) -> List[Tuple]:
        """Corresponds to psycopg2.extras.execute_values. Only available when
        using postgres.

        The `fetch` parameter must be set to False if the query does not return
        rows (e.g. INSERTs).
        """
        assert isinstance(self.database_engine, PostgresEngine)
        from psycopg2.extras import execute_values  # type: ignore

        return self._do_execute(
            lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args
        )

    def execute(self, sql: str, *args: Any) -> None:
        self._do_execute(self.txn.execute, sql, *args)

    def executemany(self, sql: str, *args: Any) -> None:
        self._do_execute(self.txn.executemany, sql, *args)

    def _make_sql_one_line(self, sql: str) -> str:
        "Strip newlines out of SQL so that the loggers in the DB are on one line"
        return " ".join(line.strip() for line in sql.splitlines() if line.strip())

    def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R:
        """Run `func(sql, *args)`, adding debug logging, an opentracing span
        and per-verb latency metrics around the call.
        """
        sql = self._make_sql_one_line(sql)

        # TODO(paul): Maybe use 'info' and 'debug' for values?
        sql_logger.debug("[SQL] {%s} %s", self.name, sql)

        sql = self.database_engine.convert_param_style(sql)
        if args:
            try:
                # Only log the first parameter set; the full list may be huge.
                sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
            except Exception:
                # Don't let logging failures stop SQL from working
                pass

        start = time.time()

        try:
            with opentracing.start_active_span(
                "db.query",
                tags={
                    opentracing.tags.DATABASE_TYPE: "sql",
                    opentracing.tags.DATABASE_STATEMENT: sql,
                },
            ):
                return func(sql, *args)
        except Exception as e:
            sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
            raise
        finally:
            secs = time.time() - start
            sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
            # Label the timing metric by the SQL verb (first word of the query).
            sql_query_timer.labels(sql.split()[0]).observe(secs)

    def close(self) -> None:
        self.txn.close()

    def __enter__(self) -> "LoggingTransaction":
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
360
361
class PerformanceCounters:
    """Accumulates per-key (count, total seconds) statistics and reports the
    keys consuming the largest share of each reporting interval."""

    def __init__(self):
        # key -> (cumulative count, cumulative seconds), since process start.
        self.current_counters = {}
        # Snapshot of `current_counters` taken at the previous `interval` call.
        self.previous_counters = {}

    def update(self, key: str, duration_secs: float) -> None:
        """Record a single occurrence of `key` that took `duration_secs`."""
        prev_count, prev_cum = self.current_counters.get(key, (0, 0))
        self.current_counters[key] = (prev_count + 1, prev_cum + duration_secs)

    def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
        """Return a summary of the top `limit` keys by share of the elapsed
        `interval_duration_secs` since the previous call."""
        deltas = []
        for key, (count, cum_time) in self.current_counters.items():
            last_count, last_time = self.previous_counters.get(key, (0, 0))
            ratio = (cum_time - last_time) / interval_duration_secs
            deltas.append((ratio, count - last_count, key))

        # Snapshot the totals so the next interval reports fresh deltas.
        self.previous_counters = dict(self.current_counters)

        deltas.sort(reverse=True)

        return ", ".join(
            "%s(%d): %.3f%%" % (key, count, 100 * ratio)
            for ratio, count, key in deltas[:limit]
        )
395
396
397class DatabasePool:
398    """Wraps a single physical database and connection pool.
399
400    A single database may be used by multiple data stores.
401    """
402
403    _TXN_ID = 0
404
    def __init__(
        self,
        hs: "HomeServer",
        database_config: DatabaseConnectionConfig,
        engine: BaseDatabaseEngine,
    ):
        self.hs = hs
        self._clock = hs.get_clock()
        # Max number of transactions to run on a single connection before it is
        # reconnected (0 = unlimited); enforced in `runWithConnection`.
        self._txn_limit = database_config.config.get("txn_limit", 0)
        self._database_config = database_config
        self._db_pool = make_pool(hs.get_reactor(), database_config, engine)

        self.updates = BackgroundUpdater(hs, self)

        # Cumulative wall-clock seconds spent inside transactions; sampled by
        # `start_profiling` to log the time share per 10s interval.
        self._previous_txn_total_time = 0.0
        self._current_txn_total_time = 0.0
        self._previous_loop_ts = 0.0

        # Transaction counter: key is the twisted thread id, value is the current count
        self._txn_counters: Dict[int, int] = defaultdict(int)

        # TODO(paul): These can eventually be removed once the metrics code
        #   is running in mainline, and we have some nice monitoring frontends
        #   to watch it
        self._txn_perf_counters = PerformanceCounters()

        self.engine = engine

        # A set of tables that are not safe to use native upserts in.
        self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())

        # We add the user_directory_search table to the blacklist on SQLite
        # because the existing search table does not have an index, making it
        # unsafe to use native upserts.
        if isinstance(self.engine, Sqlite3Engine):
            self._unsafe_to_upsert_tables.add("user_directory_search")

        if self.engine.can_native_upsert:
            # Check ASAP (and then later, every 1s) to see if we have finished
            # background updates of tables that aren't safe to update.
            self._clock.call_later(
                0.0,
                run_as_background_process,
                "upsert_safety_check",
                self._check_safe_to_upsert,
            )
451
452    def name(self) -> str:
453        "Return the name of this database"
454        return self._database_config.name
455
456    def is_running(self) -> bool:
457        """Is the database pool currently running"""
458        return self._db_pool.running
459
460    async def _check_safe_to_upsert(self) -> None:
461        """
462        Is it safe to use native UPSERT?
463
464        If there are background updates, we will need to wait, as they may be
465        the addition of indexes that set the UNIQUE constraint that we require.
466
467        If the background updates have not completed, wait 15 sec and check again.
468        """
469        updates = await self.simple_select_list(
470            "background_updates",
471            keyvalues=None,
472            retcols=["update_name"],
473            desc="check_background_updates",
474        )
475        updates = [x["update_name"] for x in updates]
476
477        for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
478            if update_name not in updates:
479                logger.debug("Now safe to upsert in %s", table)
480                self._unsafe_to_upsert_tables.discard(table)
481
482        # If there's any updates still running, reschedule to run.
483        if updates:
484            self._clock.call_later(
485                15.0,
486                run_as_background_process,
487                "upsert_safety_check",
488                self._check_safe_to_upsert,
489            )
490
491    def start_profiling(self) -> None:
492        self._previous_loop_ts = monotonic_time()
493
494        def loop():
495            curr = self._current_txn_total_time
496            prev = self._previous_txn_total_time
497            self._previous_txn_total_time = curr
498
499            time_now = monotonic_time()
500            time_then = self._previous_loop_ts
501            self._previous_loop_ts = time_now
502
503            duration = time_now - time_then
504            ratio = (curr - prev) / duration
505
506            top_three_counters = self._txn_perf_counters.interval(duration, limit=3)
507
508            perf_logger.debug(
509                "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters
510            )
511
512        self._clock.looping_call(loop, 10000)
513
    def new_transaction(
        self,
        conn: LoggingDatabaseConnection,
        desc: str,
        after_callbacks: List[_CallbackListEntry],
        exception_callbacks: List[_CallbackListEntry],
        func: Callable[..., R],
        *args: Any,
        **kwargs: Any,
    ) -> R:
        """Start a new database transaction with the given connection.

        Note: The given func may be called multiple times under certain
        failure modes. This is normally fine when in a standard transaction,
        but care must be taken if the connection is in `autocommit` mode that
        the function will correctly handle being aborted and retried half way
        through its execution.

        Similarly, the arguments to `func` (`args`, `kwargs`) should not be generators,
        since they could be evaluated multiple times (which would produce an empty
        result on the second or subsequent evaluation). Likewise, the closure of `func`
        must not reference any generators.  This method attempts to detect such usage
        and will log an error.

        Args:
            conn: the connection to run the transaction on.
            desc: description of the transaction, used for logging and metrics.
            after_callbacks: list to which (callback, args, kwargs) entries
                scheduled via `txn.call_after` are appended; the caller runs
                them on successful commit.
            exception_callbacks: as `after_callbacks`, but run on failure.
            func: callback to run inside the transaction; called with the
                cursor followed by `args` and `kwargs`.
            *args: extra positional args for `func`.
            **kwargs: extra keyword args for `func`.
        """

        # Robustness check: ensure that none of the arguments are generators, since that
        # will fail if we have to repeat the transaction.
        # For now, we just log an error, and hope that it works on the first attempt.
        # TODO: raise an exception.
        for i, arg in enumerate(args):
            if inspect.isgenerator(arg):
                logger.error(
                    "Programming error: generator passed to new_transaction as "
                    "argument %i to function %s",
                    i,
                    func,
                )
        for name, val in kwargs.items():
            if inspect.isgenerator(val):
                logger.error(
                    "Programming error: generator passed to new_transaction as "
                    "argument %s to function %s",
                    name,
                    func,
                )
        # also check variables referenced in func's closure
        if inspect.isfunction(func):
            f = cast(types.FunctionType, func)
            if f.__closure__:
                for i, cell in enumerate(f.__closure__):
                    if inspect.isgenerator(cell.cell_contents):
                        logger.error(
                            "Programming error: function %s references generator %s "
                            "via its closure",
                            f,
                            f.__code__.co_freevars[i],
                        )

        start = monotonic_time()
        txn_id = self._TXN_ID

        # We don't really need these to be unique, so lets stop it from
        # growing really large.
        self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)

        name = "%s-%x" % (desc, txn_id)

        transaction_logger.debug("[TXN START] {%s}", name)

        try:
            # Retry loop: `i` counts attempts so far, up to a maximum of `N`
            # retries on transient errors (connection loss or deadlock).
            i = 0
            N = 5
            while True:
                # A fresh cursor per attempt: the previous one is closed in
                # the `finally` below before we loop round again.
                cursor = conn.cursor(
                    txn_name=name,
                    after_callbacks=after_callbacks,
                    exception_callbacks=exception_callbacks,
                )
                try:
                    with opentracing.start_active_span(
                        "db.txn",
                        tags={
                            opentracing.SynapseTags.DB_TXN_DESC: desc,
                            opentracing.SynapseTags.DB_TXN_ID: name,
                        },
                    ):
                        r = func(cursor, *args, **kwargs)
                        opentracing.log_kv({"message": "commit"})
                        conn.commit()
                        return r
                except self.engine.module.OperationalError as e:
                    # This can happen if the database disappears mid
                    # transaction.
                    transaction_logger.warning(
                        "[TXN OPERROR] {%s} %s %d/%d",
                        name,
                        e,
                        i,
                        N,
                    )
                    if i < N:
                        i += 1
                        # Roll back before retrying; a failed rollback is only
                        # logged since `cp_reconnect` will replace the
                        # connection if it has died.
                        try:
                            with opentracing.start_active_span("db.rollback"):
                                conn.rollback()
                        except self.engine.module.Error as e1:
                            transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1)
                        continue
                    raise
                except self.engine.module.DatabaseError as e:
                    # Deadlocks are also retried (after rollback); any other
                    # DatabaseError propagates immediately.
                    if self.engine.is_deadlock(e):
                        transaction_logger.warning(
                            "[TXN DEADLOCK] {%s} %d/%d", name, i, N
                        )
                        if i < N:
                            i += 1
                            try:
                                with opentracing.start_active_span("db.rollback"):
                                    conn.rollback()
                            except self.engine.module.Error as e1:
                                transaction_logger.warning(
                                    "[TXN EROLL] {%s} %s",
                                    name,
                                    e1,
                                )
                            continue
                    raise
                finally:
                    # we're either about to retry with a new cursor, or we're about to
                    # release the connection. Once we release the connection, it could
                    # get used for another query, which might do a conn.rollback().
                    #
                    # In the latter case, even though that probably wouldn't affect the
                    # results of this transaction, python's sqlite will reset all
                    # statements on the connection [1], which will make our cursor
                    # invalid [2].
                    #
                    # In any case, continuing to read rows after commit()ing seems
                    # dubious from the PoV of ACID transactional semantics
                    # (sqlite explicitly says that once you commit, you may see rows
                    # from subsequent updates.)
                    #
                    # In psycopg2, cursors are essentially a client-side fabrication -
                    # all the data is transferred to the client side when the statement
                    # finishes executing - so in theory we could go on streaming results
                    # from the cursor, but attempting to do so would make us
                    # incompatible with sqlite, so let's make sure we're not doing that
                    # by closing the cursor.
                    #
                    # (*named* cursors in psycopg2 are different and are proper server-
                    # side things, but (a) we don't use them and (b) they are implicitly
                    # closed by ending the transaction anyway.)
                    #
                    # In short, if we haven't finished with the cursor yet, that's a
                    # problem waiting to bite us.
                    #
                    # TL;DR: we're done with the cursor, so we can close it.
                    #
                    # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465
                    # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
                    cursor.close()
        except Exception as e:
            transaction_logger.debug("[TXN FAIL] {%s} %s", name, e)
            raise
        finally:
            # Always record timing, whether the transaction succeeded or not.
            end = monotonic_time()
            duration = end - start

            current_context().add_database_transaction(duration)

            transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)

            self._current_txn_total_time += duration
            self._txn_perf_counters.update(desc, duration)
            sql_txn_timer.labels(desc).observe(duration)
698
699    async def runInteraction(
700        self,
701        desc: str,
702        func: Callable[..., R],
703        *args: Any,
704        db_autocommit: bool = False,
705        **kwargs: Any,
706    ) -> R:
707        """Starts a transaction on the database and runs a given function
708
709        Arguments:
710            desc: description of the transaction, for logging and metrics
711            func: callback function, which will be called with a
712                database transaction (twisted.enterprise.adbapi.Transaction) as
713                its first argument, followed by `args` and `kwargs`.
714
715            db_autocommit: Whether to run the function in "autocommit" mode,
716                i.e. outside of a transaction. This is useful for transactions
717                that are only a single query.
718
719                Currently, this is only implemented for Postgres. SQLite will still
720                run the function inside a transaction.
721
722                WARNING: This means that if func fails half way through then
723                the changes will *not* be rolled back. `func` may also get
724                called multiple times if the transaction is retried, so must
725                correctly handle that case.
726
727            args: positional args to pass to `func`
728            kwargs: named args to pass to `func`
729
730        Returns:
731            The result of func
732        """
733        after_callbacks: List[_CallbackListEntry] = []
734        exception_callbacks: List[_CallbackListEntry] = []
735
736        if not current_context():
737            logger.warning("Starting db txn '%s' from sentinel context", desc)
738
739        try:
740            with opentracing.start_active_span(f"db.{desc}"):
741                result = await self.runWithConnection(
742                    self.new_transaction,
743                    desc,
744                    after_callbacks,
745                    exception_callbacks,
746                    func,
747                    *args,
748                    db_autocommit=db_autocommit,
749                    **kwargs,
750                )
751
752            for after_callback, after_args, after_kwargs in after_callbacks:
753                after_callback(*after_args, **after_kwargs)
754        except Exception:
755            for after_callback, after_args, after_kwargs in exception_callbacks:
756                after_callback(*after_args, **after_kwargs)
757            raise
758
759        return cast(R, result)
760
    async def runWithConnection(
        self,
        func: Callable[..., R],
        *args: Any,
        db_autocommit: bool = False,
        **kwargs: Any,
    ) -> R:
        """Wraps the .runWithConnection() method on the underlying db_pool.

        Arguments:
            func: callback function, which will be called with a
                database connection (twisted.enterprise.adbapi.Connection) as
                its first argument, followed by `args` and `kwargs`.
            args: positional args to pass to `func`
            db_autocommit: Whether to run the function in "autocommit" mode,
                i.e. outside of a transaction. This is useful for transaction
                that are only a single query. Currently only affects postgres.
            kwargs: named args to pass to `func`

        Returns:
            The result of func
        """
        # Capture the calling logcontext now: `inner_func` runs on a pool
        # thread, where the current context would otherwise be the sentinel.
        curr_context = current_context()
        if not curr_context:
            logger.warning(
                "Starting db connection from sentinel context: metrics will be lost"
            )
            parent_context = None
        else:
            assert isinstance(curr_context, LoggingContext)
            parent_context = curr_context

        # Used below to report how long this request spent queued waiting for
        # a database thread.
        start_time = monotonic_time()

        def inner_func(conn, *args, **kwargs):
            # Runs in a thread from the underlying connection pool, with
            # `conn` being the DBAPI connection assigned to that thread.
            #
            # We shouldn't be in a transaction. If we are then something
            # somewhere hasn't committed after doing work. (This is likely only
            # possible during startup, as `run*` will ensure changes are
            # committed/rolled back before putting the connection back in the
            # pool).
            assert not self.engine.in_transaction(conn)

            with LoggingContext(
                str(curr_context), parent_context=parent_context
            ) as context:
                with opentracing.start_active_span(
                    operation_name="db.connection",
                ):
                    sched_duration_sec = monotonic_time() - start_time
                    sql_scheduling_timer.observe(sched_duration_sec)
                    context.add_database_scheduled(sched_duration_sec)

                    # If a per-connection transaction limit is configured,
                    # reconnect once this thread's connection has served more
                    # than `_txn_limit` transactions.
                    if self._txn_limit > 0:
                        tid = self._db_pool.threadID()
                        self._txn_counters[tid] += 1

                        if self._txn_counters[tid] > self._txn_limit:
                            logger.debug(
                                "Reconnecting database connection over transaction limit"
                            )
                            conn.reconnect()
                            opentracing.log_kv(
                                {"message": "reconnected due to txn limit"}
                            )
                            self._txn_counters[tid] = 1

                    if self.engine.is_connection_closed(conn):
                        logger.debug("Reconnecting closed database connection")
                        conn.reconnect()
                        opentracing.log_kv({"message": "reconnected"})
                        # A fresh connection starts from a clean slate, so
                        # reset this thread's transaction count too.
                        if self._txn_limit > 0:
                            self._txn_counters[tid] = 1

                    try:
                        if db_autocommit:
                            self.engine.attempt_to_set_autocommit(conn, True)

                        db_conn = LoggingDatabaseConnection(
                            conn, self.engine, "runWithConnection"
                        )
                        return func(db_conn, *args, **kwargs)
                    finally:
                        # Always restore the default (transactional) mode
                        # before the connection goes back into the pool.
                        if db_autocommit:
                            self.engine.attempt_to_set_autocommit(conn, False)

        return await make_deferred_yieldable(
            self._db_pool.runWithConnection(inner_func, *args, **kwargs)
        )
849
850    @staticmethod
851    def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
852        """Converts a SQL cursor into an list of dicts.
853
854        Args:
855            cursor: The DBAPI cursor which has executed a query.
856        Returns:
857            A list of dicts where the key is the column header.
858        """
859        assert cursor.description is not None, "cursor.description was None"
860        col_headers = [intern(str(column[0])) for column in cursor.description]
861        results = [dict(zip(col_headers, row)) for row in cursor]
862        return results
863
864    @overload
865    async def execute(
866        self, desc: str, decoder: Literal[None], query: str, *args: Any
867    ) -> List[Tuple[Any, ...]]:
868        ...
869
870    @overload
871    async def execute(
872        self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
873    ) -> R:
874        ...
875
876    async def execute(
877        self,
878        desc: str,
879        decoder: Optional[Callable[[Cursor], R]],
880        query: str,
881        *args: Any,
882    ) -> R:
883        """Runs a single query for a result set.
884
885        Args:
886            desc: description of the transaction, for logging and metrics
887            decoder - The function which can resolve the cursor results to
888                something meaningful.
889            query - The query string to execute
890            *args - Query args.
891        Returns:
892            The result of decoder(results)
893        """
894
895        def interaction(txn):
896            txn.execute(query, args)
897            if decoder:
898                return decoder(txn)
899            else:
900                return txn.fetchall()
901
902        return await self.runInteraction(desc, interaction)
903
904    # "Simple" SQL API methods that operate on a single table with no JOINs,
905    # no complex WHERE clauses, just a dict of values for columns.
906
907    async def simple_insert(
908        self,
909        table: str,
910        values: Dict[str, Any],
911        desc: str = "simple_insert",
912    ) -> None:
913        """Executes an INSERT query on the named table.
914
915        Args:
916            table: string giving the table name
917            values: dict of new column names and values for them
918            desc: description of the transaction, for logging and metrics
919        """
920        await self.runInteraction(desc, self.simple_insert_txn, table, values)
921
922    @staticmethod
923    def simple_insert_txn(
924        txn: LoggingTransaction, table: str, values: Dict[str, Any]
925    ) -> None:
926        keys, vals = zip(*values.items())
927
928        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
929            table,
930            ", ".join(k for k in keys),
931            ", ".join("?" for _ in keys),
932        )
933
934        txn.execute(sql, vals)
935
936    async def simple_insert_many(
937        self, table: str, values: List[Dict[str, Any]], desc: str
938    ) -> None:
939        """Executes an INSERT query on the named table.
940
941        The input is given as a list of dicts, with one dict per row.
942        Generally simple_insert_many_values should be preferred for new code.
943
944        Args:
945            table: string giving the table name
946            values: dict of new column names and values for them
947            desc: description of the transaction, for logging and metrics
948        """
949        await self.runInteraction(desc, self.simple_insert_many_txn, table, values)
950
951    @staticmethod
952    def simple_insert_many_txn(
953        txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
954    ) -> None:
955        """Executes an INSERT query on the named table.
956
957        The input is given as a list of dicts, with one dict per row.
958        Generally simple_insert_many_values_txn should be preferred for new code.
959
960        Args:
961            txn: The transaction to use.
962            table: string giving the table name
963            values: dict of new column names and values for them
964        """
965        if not values:
966            return
967
968        # This is a *slight* abomination to get a list of tuples of key names
969        # and a list of tuples of value names.
970        #
971        # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
972        #         => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
973        #
974        # The sort is to ensure that we don't rely on dictionary iteration
975        # order.
976        keys, vals = zip(
977            *(zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i)
978        )
979
980        for k in keys:
981            if k != keys[0]:
982                raise RuntimeError("All items must have the same keys")
983
984        return DatabasePool.simple_insert_many_values_txn(txn, table, keys[0], vals)
985
986    async def simple_insert_many_values(
987        self,
988        table: str,
989        keys: Collection[str],
990        values: Collection[Collection[Any]],
991        desc: str,
992    ) -> None:
993        """Executes an INSERT query on the named table.
994
995        The input is given as a list of rows, where each row is a list of values.
996        (Actually any iterable is fine.)
997
998        Args:
999            table: string giving the table name
1000            keys: list of column names
1001            values: for each row, a list of values in the same order as `keys`
1002            desc: description of the transaction, for logging and metrics
1003        """
1004        await self.runInteraction(
1005            desc, self.simple_insert_many_values_txn, table, keys, values
1006        )
1007
1008    @staticmethod
1009    def simple_insert_many_values_txn(
1010        txn: LoggingTransaction,
1011        table: str,
1012        keys: Collection[str],
1013        values: Iterable[Iterable[Any]],
1014    ) -> None:
1015        """Executes an INSERT query on the named table.
1016
1017        The input is given as a list of rows, where each row is a list of values.
1018        (Actually any iterable is fine.)
1019
1020        Args:
1021            txn: The transaction to use.
1022            table: string giving the table name
1023            keys: list of column names
1024            values: for each row, a list of values in the same order as `keys`
1025        """
1026
1027        if isinstance(txn.database_engine, PostgresEngine):
1028            # We use `execute_values` as it can be a lot faster than `execute_batch`,
1029            # but it's only available on postgres.
1030            sql = "INSERT INTO %s (%s) VALUES ?" % (
1031                table,
1032                ", ".join(k for k in keys),
1033            )
1034
1035            txn.execute_values(sql, values, fetch=False)
1036        else:
1037            sql = "INSERT INTO %s (%s) VALUES(%s)" % (
1038                table,
1039                ", ".join(k for k in keys),
1040                ", ".join("?" for _ in keys),
1041            )
1042
1043            txn.execute_batch(sql, values)
1044
1045    async def simple_upsert(
1046        self,
1047        table: str,
1048        keyvalues: Dict[str, Any],
1049        values: Dict[str, Any],
1050        insertion_values: Optional[Dict[str, Any]] = None,
1051        desc: str = "simple_upsert",
1052        lock: bool = True,
1053    ) -> bool:
1054        """
1055
1056        `lock` should generally be set to True (the default), but can be set
1057        to False if either of the following are true:
1058            1. there is a UNIQUE INDEX on the key columns. In this case a conflict
1059            will cause an IntegrityError in which case this function will retry
1060            the update.
1061            2. we somehow know that we are the only thread which will be updating
1062            this table.
1063        As an additional note, this parameter only matters for old SQLite versions
1064        because we will use native upserts otherwise.
1065
1066        Args:
1067            table: The table to upsert into
1068            keyvalues: The unique key columns and their new values
1069            values: The nonunique columns and their new values
1070            insertion_values: additional key/values to use only when inserting
1071            desc: description of the transaction, for logging and metrics
1072            lock: True to lock the table when doing the upsert.
1073        Returns:
1074            Returns True if a row was inserted or updated (i.e. if `values` is
1075            not empty then this always returns True)
1076        """
1077        insertion_values = insertion_values or {}
1078
1079        attempts = 0
1080        while True:
1081            try:
1082                # We can autocommit if we are going to use native upserts
1083                autocommit = (
1084                    self.engine.can_native_upsert
1085                    and table not in self._unsafe_to_upsert_tables
1086                )
1087
1088                return await self.runInteraction(
1089                    desc,
1090                    self.simple_upsert_txn,
1091                    table,
1092                    keyvalues,
1093                    values,
1094                    insertion_values,
1095                    lock=lock,
1096                    db_autocommit=autocommit,
1097                )
1098            except self.engine.module.IntegrityError as e:
1099                attempts += 1
1100                if attempts >= 5:
1101                    # don't retry forever, because things other than races
1102                    # can cause IntegrityErrors
1103                    raise
1104
1105                # presumably we raced with another transaction: let's retry.
1106                logger.warning(
1107                    "IntegrityError when upserting into %s; retrying: %s", table, e
1108                )
1109
1110    def simple_upsert_txn(
1111        self,
1112        txn: LoggingTransaction,
1113        table: str,
1114        keyvalues: Dict[str, Any],
1115        values: Dict[str, Any],
1116        insertion_values: Optional[Dict[str, Any]] = None,
1117        lock: bool = True,
1118    ) -> bool:
1119        """
1120        Pick the UPSERT method which works best on the platform. Either the
1121        native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
1122
1123        Args:
1124            txn: The transaction to use.
1125            table: The table to upsert into
1126            keyvalues: The unique key tables and their new values
1127            values: The nonunique columns and their new values
1128            insertion_values: additional key/values to use only when inserting
1129            lock: True to lock the table when doing the upsert.
1130        Returns:
1131            Returns True if a row was inserted or updated (i.e. if `values` is
1132            not empty then this always returns True)
1133        """
1134        insertion_values = insertion_values or {}
1135
1136        if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
1137            return self.simple_upsert_txn_native_upsert(
1138                txn, table, keyvalues, values, insertion_values=insertion_values
1139            )
1140        else:
1141            return self.simple_upsert_txn_emulated(
1142                txn,
1143                table,
1144                keyvalues,
1145                values,
1146                insertion_values=insertion_values,
1147                lock=lock,
1148            )
1149
1150    def simple_upsert_txn_emulated(
1151        self,
1152        txn: LoggingTransaction,
1153        table: str,
1154        keyvalues: Dict[str, Any],
1155        values: Dict[str, Any],
1156        insertion_values: Optional[Dict[str, Any]] = None,
1157        lock: bool = True,
1158    ) -> bool:
1159        """
1160        Args:
1161            table: The table to upsert into
1162            keyvalues: The unique key tables and their new values
1163            values: The nonunique columns and their new values
1164            insertion_values: additional key/values to use only when inserting
1165            lock: True to lock the table when doing the upsert.
1166        Returns:
1167            Returns True if a row was inserted or updated (i.e. if `values` is
1168            not empty then this always returns True)
1169        """
1170        insertion_values = insertion_values or {}
1171
1172        # We need to lock the table :(, unless we're *really* careful
1173        if lock:
1174            self.engine.lock_table(txn, table)
1175
1176        def _getwhere(key):
1177            # If the value we're passing in is None (aka NULL), we need to use
1178            # IS, not =, as NULL = NULL equals NULL (False).
1179            if keyvalues[key] is None:
1180                return "%s IS ?" % (key,)
1181            else:
1182                return "%s = ?" % (key,)
1183
1184        if not values:
1185            # If `values` is empty, then all of the values we care about are in
1186            # the unique key, so there is nothing to UPDATE. We can just do a
1187            # SELECT instead to see if it exists.
1188            sql = "SELECT 1 FROM %s WHERE %s" % (
1189                table,
1190                " AND ".join(_getwhere(k) for k in keyvalues),
1191            )
1192            sqlargs = list(keyvalues.values())
1193            txn.execute(sql, sqlargs)
1194            if txn.fetchall():
1195                # We have an existing record.
1196                return False
1197        else:
1198            # First try to update.
1199            sql = "UPDATE %s SET %s WHERE %s" % (
1200                table,
1201                ", ".join("%s = ?" % (k,) for k in values),
1202                " AND ".join(_getwhere(k) for k in keyvalues),
1203            )
1204            sqlargs = list(values.values()) + list(keyvalues.values())
1205
1206            txn.execute(sql, sqlargs)
1207            if txn.rowcount > 0:
1208                return True
1209
1210        # We didn't find any existing rows, so insert a new one
1211        allvalues: Dict[str, Any] = {}
1212        allvalues.update(keyvalues)
1213        allvalues.update(values)
1214        allvalues.update(insertion_values)
1215
1216        sql = "INSERT INTO %s (%s) VALUES (%s)" % (
1217            table,
1218            ", ".join(k for k in allvalues),
1219            ", ".join("?" for _ in allvalues),
1220        )
1221        txn.execute(sql, list(allvalues.values()))
1222        # successfully inserted
1223        return True
1224
1225    def simple_upsert_txn_native_upsert(
1226        self,
1227        txn: LoggingTransaction,
1228        table: str,
1229        keyvalues: Dict[str, Any],
1230        values: Dict[str, Any],
1231        insertion_values: Optional[Dict[str, Any]] = None,
1232    ) -> bool:
1233        """
1234        Use the native UPSERT functionality in PostgreSQL.
1235
1236        Args:
1237            table: The table to upsert into
1238            keyvalues: The unique key tables and their new values
1239            values: The nonunique columns and their new values
1240            insertion_values: additional key/values to use only when inserting
1241
1242        Returns:
1243            Returns True if a row was inserted or updated (i.e. if `values` is
1244            not empty then this always returns True)
1245        """
1246        allvalues: Dict[str, Any] = {}
1247        allvalues.update(keyvalues)
1248        allvalues.update(insertion_values or {})
1249
1250        if not values:
1251            latter = "NOTHING"
1252        else:
1253            allvalues.update(values)
1254            latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
1255
1256        sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
1257            table,
1258            ", ".join(k for k in allvalues),
1259            ", ".join("?" for _ in allvalues),
1260            ", ".join(k for k in keyvalues),
1261            latter,
1262        )
1263        txn.execute(sql, list(allvalues.values()))
1264
1265        return bool(txn.rowcount)
1266
1267    async def simple_upsert_many(
1268        self,
1269        table: str,
1270        key_names: Collection[str],
1271        key_values: Collection[Collection[Any]],
1272        value_names: Collection[str],
1273        value_values: Collection[Collection[Any]],
1274        desc: str,
1275    ) -> None:
1276        """
1277        Upsert, many times.
1278
1279        Args:
1280            table: The table to upsert into
1281            key_names: The key column names.
1282            key_values: A list of each row's key column values.
1283            value_names: The value column names
1284            value_values: A list of each row's value column values.
1285                Ignored if value_names is empty.
1286        """
1287
1288        # We can autocommit if we are going to use native upserts
1289        autocommit = (
1290            self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables
1291        )
1292
1293        return await self.runInteraction(
1294            desc,
1295            self.simple_upsert_many_txn,
1296            table,
1297            key_names,
1298            key_values,
1299            value_names,
1300            value_values,
1301            db_autocommit=autocommit,
1302        )
1303
1304    def simple_upsert_many_txn(
1305        self,
1306        txn: LoggingTransaction,
1307        table: str,
1308        key_names: Collection[str],
1309        key_values: Collection[Iterable[Any]],
1310        value_names: Collection[str],
1311        value_values: Iterable[Iterable[Any]],
1312    ) -> None:
1313        """
1314        Upsert, many times.
1315
1316        Args:
1317            table: The table to upsert into
1318            key_names: The key column names.
1319            key_values: A list of each row's key column values.
1320            value_names: The value column names
1321            value_values: A list of each row's value column values.
1322                Ignored if value_names is empty.
1323        """
1324        if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
1325            return self.simple_upsert_many_txn_native_upsert(
1326                txn, table, key_names, key_values, value_names, value_values
1327            )
1328        else:
1329            return self.simple_upsert_many_txn_emulated(
1330                txn, table, key_names, key_values, value_names, value_values
1331            )
1332
1333    def simple_upsert_many_txn_emulated(
1334        self,
1335        txn: LoggingTransaction,
1336        table: str,
1337        key_names: Iterable[str],
1338        key_values: Collection[Iterable[Any]],
1339        value_names: Collection[str],
1340        value_values: Iterable[Iterable[Any]],
1341    ) -> None:
1342        """
1343        Upsert, many times, but without native UPSERT support or batching.
1344
1345        Args:
1346            table: The table to upsert into
1347            key_names: The key column names.
1348            key_values: A list of each row's key column values.
1349            value_names: The value column names
1350            value_values: A list of each row's value column values.
1351                Ignored if value_names is empty.
1352        """
1353        # No value columns, therefore make a blank list so that the following
1354        # zip() works correctly.
1355        if not value_names:
1356            value_values = [() for x in range(len(key_values))]
1357
1358        for keyv, valv in zip(key_values, value_values):
1359            _keys = {x: y for x, y in zip(key_names, keyv)}
1360            _vals = {x: y for x, y in zip(value_names, valv)}
1361
1362            self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
1363
1364    def simple_upsert_many_txn_native_upsert(
1365        self,
1366        txn: LoggingTransaction,
1367        table: str,
1368        key_names: Collection[str],
1369        key_values: Collection[Iterable[Any]],
1370        value_names: Collection[str],
1371        value_values: Iterable[Iterable[Any]],
1372    ) -> None:
1373        """
1374        Upsert, many times, using batching where possible.
1375
1376        Args:
1377            table: The table to upsert into
1378            key_names: The key column names.
1379            key_values: A list of each row's key column values.
1380            value_names: The value column names
1381            value_values: A list of each row's value column values.
1382                Ignored if value_names is empty.
1383        """
1384        allnames: List[str] = []
1385        allnames.extend(key_names)
1386        allnames.extend(value_names)
1387
1388        if not value_names:
1389            # No value columns, therefore make a blank list so that the
1390            # following zip() works correctly.
1391            latter = "NOTHING"
1392            value_values = [() for x in range(len(key_values))]
1393        else:
1394            latter = "UPDATE SET " + ", ".join(
1395                k + "=EXCLUDED." + k for k in value_names
1396            )
1397
1398        args = []
1399
1400        for x, y in zip(key_values, value_values):
1401            args.append(tuple(x) + tuple(y))
1402
1403        if isinstance(txn.database_engine, PostgresEngine):
1404            # We use `execute_values` as it can be a lot faster than `execute_batch`,
1405            # but it's only available on postgres.
1406            sql = "INSERT INTO %s (%s) VALUES ? ON CONFLICT (%s) DO %s" % (
1407                table,
1408                ", ".join(k for k in allnames),
1409                ", ".join(key_names),
1410                latter,
1411            )
1412
1413            txn.execute_values(sql, args, fetch=False)
1414
1415        else:
1416            sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
1417                table,
1418                ", ".join(k for k in allnames),
1419                ", ".join("?" for _ in allnames),
1420                ", ".join(key_names),
1421                latter,
1422            )
1423
1424            return txn.execute_batch(sql, args)
1425
1426    @overload
1427    async def simple_select_one(
1428        self,
1429        table: str,
1430        keyvalues: Dict[str, Any],
1431        retcols: Collection[str],
1432        allow_none: Literal[False] = False,
1433        desc: str = "simple_select_one",
1434    ) -> Dict[str, Any]:
1435        ...
1436
1437    @overload
1438    async def simple_select_one(
1439        self,
1440        table: str,
1441        keyvalues: Dict[str, Any],
1442        retcols: Collection[str],
1443        allow_none: Literal[True] = True,
1444        desc: str = "simple_select_one",
1445    ) -> Optional[Dict[str, Any]]:
1446        ...
1447
1448    async def simple_select_one(
1449        self,
1450        table: str,
1451        keyvalues: Dict[str, Any],
1452        retcols: Collection[str],
1453        allow_none: bool = False,
1454        desc: str = "simple_select_one",
1455    ) -> Optional[Dict[str, Any]]:
1456        """Executes a SELECT query on the named table, which is expected to
1457        return a single row, returning multiple columns from it.
1458
1459        Args:
1460            table: string giving the table name
1461            keyvalues: dict of column names and values to select the row with
1462            retcols: list of strings giving the names of the columns to return
1463            allow_none: If true, return None instead of failing if the SELECT
1464                statement returns no rows
1465            desc: description of the transaction, for logging and metrics
1466        """
1467        return await self.runInteraction(
1468            desc,
1469            self.simple_select_one_txn,
1470            table,
1471            keyvalues,
1472            retcols,
1473            allow_none,
1474            db_autocommit=True,
1475        )
1476
1477    @overload
1478    async def simple_select_one_onecol(
1479        self,
1480        table: str,
1481        keyvalues: Dict[str, Any],
1482        retcol: str,
1483        allow_none: Literal[False] = False,
1484        desc: str = "simple_select_one_onecol",
1485    ) -> Any:
1486        ...
1487
1488    @overload
1489    async def simple_select_one_onecol(
1490        self,
1491        table: str,
1492        keyvalues: Dict[str, Any],
1493        retcol: str,
1494        allow_none: Literal[True] = True,
1495        desc: str = "simple_select_one_onecol",
1496    ) -> Optional[Any]:
1497        ...
1498
1499    async def simple_select_one_onecol(
1500        self,
1501        table: str,
1502        keyvalues: Dict[str, Any],
1503        retcol: str,
1504        allow_none: bool = False,
1505        desc: str = "simple_select_one_onecol",
1506    ) -> Optional[Any]:
1507        """Executes a SELECT query on the named table, which is expected to
1508        return a single row, returning a single column from it.
1509
1510        Args:
1511            table: string giving the table name
1512            keyvalues: dict of column names and values to select the row with
1513            retcol: string giving the name of the column to return
1514            allow_none: If true, return None instead of failing if the SELECT
1515                statement returns no rows
1516            desc: description of the transaction, for logging and metrics
1517        """
1518        return await self.runInteraction(
1519            desc,
1520            self.simple_select_one_onecol_txn,
1521            table,
1522            keyvalues,
1523            retcol,
1524            allow_none=allow_none,
1525            db_autocommit=True,
1526        )
1527
1528    @overload
1529    @classmethod
1530    def simple_select_one_onecol_txn(
1531        cls,
1532        txn: LoggingTransaction,
1533        table: str,
1534        keyvalues: Dict[str, Any],
1535        retcol: str,
1536        allow_none: Literal[False] = False,
1537    ) -> Any:
1538        ...
1539
1540    @overload
1541    @classmethod
1542    def simple_select_one_onecol_txn(
1543        cls,
1544        txn: LoggingTransaction,
1545        table: str,
1546        keyvalues: Dict[str, Any],
1547        retcol: str,
1548        allow_none: Literal[True] = True,
1549    ) -> Optional[Any]:
1550        ...
1551
1552    @classmethod
1553    def simple_select_one_onecol_txn(
1554        cls,
1555        txn: LoggingTransaction,
1556        table: str,
1557        keyvalues: Dict[str, Any],
1558        retcol: str,
1559        allow_none: bool = False,
1560    ) -> Optional[Any]:
1561        ret = cls.simple_select_onecol_txn(
1562            txn, table=table, keyvalues=keyvalues, retcol=retcol
1563        )
1564
1565        if ret:
1566            return ret[0]
1567        else:
1568            if allow_none:
1569                return None
1570            else:
1571                raise StoreError(404, "No row found")
1572
1573    @staticmethod
1574    def simple_select_onecol_txn(
1575        txn: LoggingTransaction,
1576        table: str,
1577        keyvalues: Dict[str, Any],
1578        retcol: str,
1579    ) -> List[Any]:
1580        sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
1581
1582        if keyvalues:
1583            sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
1584            txn.execute(sql, list(keyvalues.values()))
1585        else:
1586            txn.execute(sql)
1587
1588        return [r[0] for r in txn]
1589
1590    async def simple_select_onecol(
1591        self,
1592        table: str,
1593        keyvalues: Optional[Dict[str, Any]],
1594        retcol: str,
1595        desc: str = "simple_select_onecol",
1596    ) -> List[Any]:
1597        """Executes a SELECT query on the named table, which returns a list
1598        comprising of the values of the named column from the selected rows.
1599
1600        Args:
1601            table: table name
1602            keyvalues: column names and values to select the rows with
1603            retcol: column whos value we wish to retrieve.
1604            desc: description of the transaction, for logging and metrics
1605
1606        Returns:
1607            Results in a list
1608        """
1609        return await self.runInteraction(
1610            desc,
1611            self.simple_select_onecol_txn,
1612            table,
1613            keyvalues,
1614            retcol,
1615            db_autocommit=True,
1616        )
1617
1618    async def simple_select_list(
1619        self,
1620        table: str,
1621        keyvalues: Optional[Dict[str, Any]],
1622        retcols: Collection[str],
1623        desc: str = "simple_select_list",
1624    ) -> List[Dict[str, Any]]:
1625        """Executes a SELECT query on the named table, which may return zero or
1626        more rows, returning the result as a list of dicts.
1627
1628        Args:
1629            table: the table name
1630            keyvalues:
1631                column names and values to select the rows with, or None to not
1632                apply a WHERE clause.
1633            retcols: the names of the columns to return
1634            desc: description of the transaction, for logging and metrics
1635
1636        Returns:
1637            A list of dictionaries.
1638        """
1639        return await self.runInteraction(
1640            desc,
1641            self.simple_select_list_txn,
1642            table,
1643            keyvalues,
1644            retcols,
1645            db_autocommit=True,
1646        )
1647
1648    @classmethod
1649    def simple_select_list_txn(
1650        cls,
1651        txn: LoggingTransaction,
1652        table: str,
1653        keyvalues: Optional[Dict[str, Any]],
1654        retcols: Iterable[str],
1655    ) -> List[Dict[str, Any]]:
1656        """Executes a SELECT query on the named table, which may return zero or
1657        more rows, returning the result as a list of dicts.
1658
1659        Args:
1660            txn: Transaction object
1661            table: the table name
1662            keyvalues:
1663                column names and values to select the rows with, or None to not
1664                apply a WHERE clause.
1665            retcols: the names of the columns to return
1666        """
1667        if keyvalues:
1668            sql = "SELECT %s FROM %s WHERE %s" % (
1669                ", ".join(retcols),
1670                table,
1671                " AND ".join("%s = ?" % (k,) for k in keyvalues),
1672            )
1673            txn.execute(sql, list(keyvalues.values()))
1674        else:
1675            sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
1676            txn.execute(sql)
1677
1678        return cls.cursor_to_dict(txn)
1679
1680    async def simple_select_many_batch(
1681        self,
1682        table: str,
1683        column: str,
1684        iterable: Iterable[Any],
1685        retcols: Collection[str],
1686        keyvalues: Optional[Dict[str, Any]] = None,
1687        desc: str = "simple_select_many_batch",
1688        batch_size: int = 100,
1689    ) -> List[Any]:
1690        """Executes a SELECT query on the named table, which may return zero or
1691        more rows, returning the result as a list of dicts.
1692
1693        Filters rows by whether the value of `column` is in `iterable`.
1694
1695        Args:
1696            table: string giving the table name
1697            column: column name to test for inclusion against `iterable`
1698            iterable: list
1699            retcols: list of strings giving the names of the columns to return
1700            keyvalues: dict of column names and values to select the rows with
1701            desc: description of the transaction, for logging and metrics
1702            batch_size: the number of rows for each select query
1703        """
1704        keyvalues = keyvalues or {}
1705
1706        results: List[Dict[str, Any]] = []
1707
1708        for chunk in batch_iter(iterable, batch_size):
1709            rows = await self.runInteraction(
1710                desc,
1711                self.simple_select_many_txn,
1712                table,
1713                column,
1714                chunk,
1715                keyvalues,
1716                retcols,
1717                db_autocommit=True,
1718            )
1719
1720            results.extend(rows)
1721
1722        return results
1723
1724    @classmethod
1725    def simple_select_many_txn(
1726        cls,
1727        txn: LoggingTransaction,
1728        table: str,
1729        column: str,
1730        iterable: Collection[Any],
1731        keyvalues: Dict[str, Any],
1732        retcols: Iterable[str],
1733    ) -> List[Dict[str, Any]]:
1734        """Executes a SELECT query on the named table, which may return zero or
1735        more rows, returning the result as a list of dicts.
1736
1737        Filters rows by whether the value of `column` is in `iterable`.
1738
1739        Args:
1740            txn: Transaction object
1741            table: string giving the table name
1742            column: column name to test for inclusion against `iterable`
1743            iterable: list
1744            keyvalues: dict of column names and values to select the rows with
1745            retcols: list of strings giving the names of the columns to return
1746        """
1747        if not iterable:
1748            return []
1749
1750        clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
1751        clauses = [clause]
1752
1753        for key, value in keyvalues.items():
1754            clauses.append("%s = ?" % (key,))
1755            values.append(value)
1756
1757        sql = "SELECT %s FROM %s WHERE %s" % (
1758            ", ".join(retcols),
1759            table,
1760            " AND ".join(clauses),
1761        )
1762
1763        txn.execute(sql, values)
1764        return cls.cursor_to_dict(txn)
1765
1766    async def simple_update(
1767        self,
1768        table: str,
1769        keyvalues: Dict[str, Any],
1770        updatevalues: Dict[str, Any],
1771        desc: str,
1772    ) -> int:
1773        return await self.runInteraction(
1774            desc, self.simple_update_txn, table, keyvalues, updatevalues
1775        )
1776
1777    @staticmethod
1778    def simple_update_txn(
1779        txn: LoggingTransaction,
1780        table: str,
1781        keyvalues: Dict[str, Any],
1782        updatevalues: Dict[str, Any],
1783    ) -> int:
1784        if keyvalues:
1785            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
1786        else:
1787            where = ""
1788
1789        update_sql = "UPDATE %s SET %s %s" % (
1790            table,
1791            ", ".join("%s = ?" % (k,) for k in updatevalues),
1792            where,
1793        )
1794
1795        txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
1796
1797        return txn.rowcount
1798
1799    async def simple_update_one(
1800        self,
1801        table: str,
1802        keyvalues: Dict[str, Any],
1803        updatevalues: Dict[str, Any],
1804        desc: str = "simple_update_one",
1805    ) -> None:
1806        """Executes an UPDATE query on the named table, setting new values for
1807        columns in a row matching the key values.
1808
1809        Args:
1810            table: string giving the table name
1811            keyvalues: dict of column names and values to select the row with
1812            updatevalues: dict giving column names and values to update
1813            desc: description of the transaction, for logging and metrics
1814        """
1815        await self.runInteraction(
1816            desc,
1817            self.simple_update_one_txn,
1818            table,
1819            keyvalues,
1820            updatevalues,
1821            db_autocommit=True,
1822        )
1823
1824    @classmethod
1825    def simple_update_one_txn(
1826        cls,
1827        txn: LoggingTransaction,
1828        table: str,
1829        keyvalues: Dict[str, Any],
1830        updatevalues: Dict[str, Any],
1831    ) -> None:
1832        rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
1833
1834        if rowcount == 0:
1835            raise StoreError(404, "No row found (%s)" % (table,))
1836        if rowcount > 1:
1837            raise StoreError(500, "More than one row matched (%s)" % (table,))
1838
1839    # Ideally we could use the overload decorator here to specify that the
1840    # return type is only optional if allow_none is True, but this does not work
1841    # when you call a static method from an instance.
1842    # See https://github.com/python/mypy/issues/7781
1843    @staticmethod
1844    def simple_select_one_txn(
1845        txn: LoggingTransaction,
1846        table: str,
1847        keyvalues: Dict[str, Any],
1848        retcols: Collection[str],
1849        allow_none: bool = False,
1850    ) -> Optional[Dict[str, Any]]:
1851        select_sql = "SELECT %s FROM %s WHERE %s" % (
1852            ", ".join(retcols),
1853            table,
1854            " AND ".join("%s = ?" % (k,) for k in keyvalues),
1855        )
1856
1857        txn.execute(select_sql, list(keyvalues.values()))
1858        row = txn.fetchone()
1859
1860        if not row:
1861            if allow_none:
1862                return None
1863            raise StoreError(404, "No row found (%s)" % (table,))
1864        if txn.rowcount > 1:
1865            raise StoreError(500, "More than one row matched (%s)" % (table,))
1866
1867        return dict(zip(retcols, row))
1868
1869    async def simple_delete_one(
1870        self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
1871    ) -> None:
1872        """Executes a DELETE query on the named table, expecting to delete a
1873        single row.
1874
1875        Args:
1876            table: string giving the table name
1877            keyvalues: dict of column names and values to select the row with
1878            desc: description of the transaction, for logging and metrics
1879        """
1880        await self.runInteraction(
1881            desc,
1882            self.simple_delete_one_txn,
1883            table,
1884            keyvalues,
1885            db_autocommit=True,
1886        )
1887
1888    @staticmethod
1889    def simple_delete_one_txn(
1890        txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
1891    ) -> None:
1892        """Executes a DELETE query on the named table, expecting to delete a
1893        single row.
1894
1895        Args:
1896            table: string giving the table name
1897            keyvalues: dict of column names and values to select the row with
1898        """
1899        sql = "DELETE FROM %s WHERE %s" % (
1900            table,
1901            " AND ".join("%s = ?" % (k,) for k in keyvalues),
1902        )
1903
1904        txn.execute(sql, list(keyvalues.values()))
1905        if txn.rowcount == 0:
1906            raise StoreError(404, "No row found (%s)" % (table,))
1907        if txn.rowcount > 1:
1908            raise StoreError(500, "More than one row matched (%s)" % (table,))
1909
1910    async def simple_delete(
1911        self, table: str, keyvalues: Dict[str, Any], desc: str
1912    ) -> int:
1913        """Executes a DELETE query on the named table.
1914
1915        Filters rows by the key-value pairs.
1916
1917        Args:
1918            table: string giving the table name
1919            keyvalues: dict of column names and values to select the row with
1920            desc: description of the transaction, for logging and metrics
1921
1922        Returns:
1923            The number of deleted rows.
1924        """
1925        return await self.runInteraction(
1926            desc, self.simple_delete_txn, table, keyvalues, db_autocommit=True
1927        )
1928
1929    @staticmethod
1930    def simple_delete_txn(
1931        txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
1932    ) -> int:
1933        """Executes a DELETE query on the named table.
1934
1935        Filters rows by the key-value pairs.
1936
1937        Args:
1938            table: string giving the table name
1939            keyvalues: dict of column names and values to select the row with
1940
1941        Returns:
1942            The number of deleted rows.
1943        """
1944        sql = "DELETE FROM %s WHERE %s" % (
1945            table,
1946            " AND ".join("%s = ?" % (k,) for k in keyvalues),
1947        )
1948
1949        txn.execute(sql, list(keyvalues.values()))
1950        return txn.rowcount
1951
1952    async def simple_delete_many(
1953        self,
1954        table: str,
1955        column: str,
1956        iterable: Collection[Any],
1957        keyvalues: Dict[str, Any],
1958        desc: str,
1959    ) -> int:
1960        """Executes a DELETE query on the named table.
1961
1962        Filters rows by if value of `column` is in `iterable`.
1963
1964        Args:
1965            table: string giving the table name
1966            column: column name to test for inclusion against `iterable`
1967            iterable: list of values to match against `column`. NB cannot be a generator
1968                as it may be evaluated multiple times.
1969            keyvalues: dict of column names and values to select the rows with
1970            desc: description of the transaction, for logging and metrics
1971
1972        Returns:
1973            Number rows deleted
1974        """
1975        return await self.runInteraction(
1976            desc,
1977            self.simple_delete_many_txn,
1978            table,
1979            column,
1980            iterable,
1981            keyvalues,
1982            db_autocommit=True,
1983        )
1984
1985    @staticmethod
1986    def simple_delete_many_txn(
1987        txn: LoggingTransaction,
1988        table: str,
1989        column: str,
1990        values: Collection[Any],
1991        keyvalues: Dict[str, Any],
1992    ) -> int:
1993        """Executes a DELETE query on the named table.
1994
1995        Deletes the rows:
1996          - whose value of `column` is in `values`; AND
1997          - that match extra column-value pairs specified in `keyvalues`.
1998
1999        Args:
2000            txn: Transaction object
2001            table: string giving the table name
2002            column: column name to test for inclusion against `values`
2003            values: values of `column` which choose rows to delete
2004            keyvalues: dict of extra column names and values to select the rows
2005                with. They will be ANDed together with the main predicate.
2006
2007        Returns:
2008            Number rows deleted
2009        """
2010        if not values:
2011            return 0
2012
2013        sql = "DELETE FROM %s" % table
2014
2015        clause, values = make_in_list_sql_clause(txn.database_engine, column, values)
2016        clauses = [clause]
2017
2018        for key, value in keyvalues.items():
2019            clauses.append("%s = ?" % (key,))
2020            values.append(value)
2021
2022        if clauses:
2023            sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
2024        txn.execute(sql, values)
2025
2026        return txn.rowcount
2027
2028    def get_cache_dict(
2029        self,
2030        db_conn: LoggingDatabaseConnection,
2031        table: str,
2032        entity_column: str,
2033        stream_column: str,
2034        max_value: int,
2035        limit: int = 100000,
2036    ) -> Tuple[Dict[Any, int], int]:
2037        # Fetch a mapping of room_id -> max stream position for "recent" rooms.
2038        # It doesn't really matter how many we get, the StreamChangeCache will
2039        # do the right thing to ensure it respects the max size of cache.
2040        sql = (
2041            "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
2042            " WHERE %(stream)s > ? - %(limit)s"
2043            " GROUP BY %(entity)s"
2044        ) % {
2045            "table": table,
2046            "entity": entity_column,
2047            "stream": stream_column,
2048            "limit": limit,
2049        }
2050
2051        txn = db_conn.cursor(txn_name="get_cache_dict")
2052        txn.execute(sql, (int(max_value),))
2053
2054        cache = {row[0]: int(row[1]) for row in txn}
2055
2056        txn.close()
2057
2058        if cache:
2059            min_val = min(cache.values())
2060        else:
2061            min_val = max_value
2062
2063        return cache, min_val
2064
2065    @classmethod
2066    def simple_select_list_paginate_txn(
2067        cls,
2068        txn: LoggingTransaction,
2069        table: str,
2070        orderby: str,
2071        start: int,
2072        limit: int,
2073        retcols: Iterable[str],
2074        filters: Optional[Dict[str, Any]] = None,
2075        keyvalues: Optional[Dict[str, Any]] = None,
2076        exclude_keyvalues: Optional[Dict[str, Any]] = None,
2077        order_direction: str = "ASC",
2078    ) -> List[Dict[str, Any]]:
2079        """
2080        Executes a SELECT query on the named table with start and limit,
2081        of row numbers, which may return zero or number of rows from start to limit,
2082        returning the result as a list of dicts.
2083
2084        Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to
2085        select attributes with exact matches. All constraints are joined together
2086        using 'AND'.
2087
2088        Args:
2089            txn: Transaction object
2090            table: the table name
2091            orderby: Column to order the results by.
2092            start: Index to begin the query at.
2093            limit: Number of results to return.
2094            retcols: the names of the columns to return
2095            filters:
2096                column names and values to filter the rows with, or None to not
2097                apply a WHERE ? LIKE ? clause.
2098            keyvalues:
2099                column names and values to select the rows with, or None to not
2100                apply a WHERE key = value clause.
2101            exclude_keyvalues:
2102                column names and values to exclude rows with, or None to not
2103                apply a WHERE key != value clause.
2104            order_direction: Whether the results should be ordered "ASC" or "DESC".
2105
2106        Returns:
2107            The result as a list of dictionaries.
2108        """
2109        if order_direction not in ["ASC", "DESC"]:
2110            raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
2111
2112        where_clause = "WHERE " if filters or keyvalues or exclude_keyvalues else ""
2113        arg_list: List[Any] = []
2114        if filters:
2115            where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
2116            arg_list += list(filters.values())
2117        where_clause += " AND " if filters and keyvalues else ""
2118        if keyvalues:
2119            where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues)
2120            arg_list += list(keyvalues.values())
2121        if exclude_keyvalues:
2122            where_clause += " AND ".join("%s != ?" % (k,) for k in exclude_keyvalues)
2123            arg_list += list(exclude_keyvalues.values())
2124
2125        sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
2126            ", ".join(retcols),
2127            table,
2128            where_clause,
2129            orderby,
2130            order_direction,
2131        )
2132        txn.execute(sql, arg_list + [limit, start])
2133
2134        return cls.cursor_to_dict(txn)
2135
2136    async def simple_search_list(
2137        self,
2138        table: str,
2139        term: Optional[str],
2140        col: str,
2141        retcols: Collection[str],
2142        desc="simple_search_list",
2143    ) -> Optional[List[Dict[str, Any]]]:
2144        """Executes a SELECT query on the named table, which may return zero or
2145        more rows, returning the result as a list of dicts.
2146
2147        Args:
2148            table: the table name
2149            term: term for searching the table matched to a column.
2150            col: column to query term should be matched to
2151            retcols: the names of the columns to return
2152
2153        Returns:
2154            A list of dictionaries or None.
2155        """
2156
2157        return await self.runInteraction(
2158            desc,
2159            self.simple_search_list_txn,
2160            table,
2161            term,
2162            col,
2163            retcols,
2164            db_autocommit=True,
2165        )
2166
2167    @classmethod
2168    def simple_search_list_txn(
2169        cls,
2170        txn: LoggingTransaction,
2171        table: str,
2172        term: Optional[str],
2173        col: str,
2174        retcols: Iterable[str],
2175    ) -> Optional[List[Dict[str, Any]]]:
2176        """Executes a SELECT query on the named table, which may return zero or
2177        more rows, returning the result as a list of dicts.
2178
2179        Args:
2180            txn: Transaction object
2181            table: the table name
2182            term: term for searching the table matched to a column.
2183            col: column to query term should be matched to
2184            retcols: the names of the columns to return
2185
2186        Returns:
2187            None if no term is given, otherwise a list of dictionaries.
2188        """
2189        if term:
2190            sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
2191            termvalues = ["%%" + term + "%%"]
2192            txn.execute(sql, termvalues)
2193        else:
2194            return None
2195
2196        return cls.cursor_to_dict(txn)
2197
2198
def make_in_list_sql_clause(
    database_engine: BaseDatabaseEngine, column: str, iterable: Collection[Any]
) -> Tuple[str, list]:
    """Returns an SQL clause that checks the given column is in the iterable.

    On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
    it expands to `column = ANY(?)`. While both DBs support the `IN` form,
    using the `ANY` form on postgres means that it views queries with
    different length iterables as the same, helping the query stats.

    Args:
        database_engine
        column: Name of the column
        iterable: The values to check the column against.

    Returns:
        A tuple of SQL query and the args
    """

    if not database_engine.supports_using_any_list:
        placeholders = ",".join("?" for _ in iterable)
        return "%s IN (%s)" % (column, placeholders), list(iterable)

    # Postgres: the query text is identical however many values there are,
    # which keeps the query stats readable (and should be a little faster).
    return "%s = ANY(?)" % (column,), [list(iterable)]
2224
2225
KV = TypeVar("KV")


def make_tuple_comparison_clause(keys: List[Tuple[str, KV]]) -> Tuple[str, List[KV]]:
    """Returns a tuple comparison SQL clause

    Builds a SQL clause that looks like "(a, b) > (?, ?)"

    Args:
        keys: A set of (column, value) pairs to be compared.

    Returns:
        A tuple of SQL query and the args
    """
    columns = ",".join(column for column, _ in keys)
    placeholders = ",".join("?" for _ in keys)
    return "(%s) > (%s)" % (columns, placeholders), [value for _, value in keys]
2244