1# Copyright 2014-2016 OpenMarket Ltd 2# Copyright 2017-2018 New Vector Ltd 3# Copyright 2019 The Matrix.org Foundation C.I.C. 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16import inspect 17import logging 18import time 19import types 20from collections import defaultdict 21from sys import intern 22from time import monotonic as monotonic_time 23from typing import ( 24 TYPE_CHECKING, 25 Any, 26 Callable, 27 Collection, 28 Dict, 29 Iterable, 30 Iterator, 31 List, 32 Optional, 33 Tuple, 34 TypeVar, 35 cast, 36 overload, 37) 38 39import attr 40from prometheus_client import Histogram 41from typing_extensions import Literal 42 43from twisted.enterprise import adbapi 44 45from synapse.api.errors import StoreError 46from synapse.config.database import DatabaseConnectionConfig 47from synapse.logging import opentracing 48from synapse.logging.context import ( 49 LoggingContext, 50 current_context, 51 make_deferred_yieldable, 52) 53from synapse.metrics import register_threadpool 54from synapse.metrics.background_process_metrics import run_as_background_process 55from synapse.storage.background_updates import BackgroundUpdater 56from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine 57from synapse.storage.types import Connection, Cursor 58from synapse.util.iterutils import batch_iter 59 60if TYPE_CHECKING: 61 from synapse.server import HomeServer 62 63# python 3 does not have a maximum int value 64MAX_TXN_ID = 2 ** 63 - 1 65 66logger = logging.getLogger(__name__) 67 68sql_logger = logging.getLogger("synapse.storage.SQL") 69transaction_logger = logging.getLogger("synapse.storage.txn") 70perf_logger = logging.getLogger("synapse.storage.TIME") 71 72sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec") 73 74sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"]) 75sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"]) 76 77 78# Unique indexes which have been added in background updates. Maps from table name 79# to the name of the background update which added the unique index to that table. 80# 81# This is used by the upsert logic to figure out which tables are safe to do a proper 82# UPSERT on: until the relevant background update has completed, we 83# have to emulate an upsert by locking the table. 84# 85UNIQUE_INDEX_BACKGROUND_UPDATES = { 86 "user_ips": "user_ips_device_unique_index", 87 "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx", 88 "device_lists_remote_cache": "device_lists_remote_cache_unique_idx", 89 "event_search": "event_search_event_id_idx", 90} 91 92 93def make_pool( 94 reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine 95) -> adbapi.ConnectionPool: 96 """Get the connection pool for the database.""" 97 98 # By default enable `cp_reconnect`. We need to fiddle with db_args in case 99 # someone has explicitly set `cp_reconnect`. 100 db_args = dict(db_config.config.get("args", {})) 101 db_args.setdefault("cp_reconnect", True) 102 103 def _on_new_connection(conn): 104 # Ensure we have a logging context so we can correctly track queries, 105 # etc. 106 with LoggingContext("db.on_new_connection"): 107 engine.on_new_connection( 108 LoggingDatabaseConnection(conn, engine, "on_new_connection") 109 ) 110 111 connection_pool = adbapi.ConnectionPool( 112 db_config.config["name"], 113 cp_reactor=reactor, 114 cp_openfun=_on_new_connection, 115 **db_args, 116 ) 117 118 register_threadpool(f"database-{db_config.name}", connection_pool.threadpool) 119 120 return connection_pool 121 122 123def make_conn( 124 db_config: DatabaseConnectionConfig, 125 engine: BaseDatabaseEngine, 126 default_txn_name: str, 127) -> "LoggingDatabaseConnection": 128 """Make a new connection to the database and return it. 129 130 Returns: 131 Connection 132 """ 133 134 db_params = { 135 k: v 136 for k, v in db_config.config.get("args", {}).items() 137 if not k.startswith("cp_") 138 } 139 native_db_conn = engine.module.connect(**db_params) 140 db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name) 141 142 engine.on_new_connection(db_conn) 143 return db_conn 144 145 146@attr.s(slots=True) 147class LoggingDatabaseConnection: 148 """A wrapper around a database connection that returns `LoggingTransaction` 149 as its cursor class. 150 151 This is mainly used on startup to ensure that queries get logged correctly 152 """ 153 154 conn = attr.ib(type=Connection) 155 engine = attr.ib(type=BaseDatabaseEngine) 156 default_txn_name = attr.ib(type=str) 157 158 def cursor( 159 self, *, txn_name=None, after_callbacks=None, exception_callbacks=None 160 ) -> "LoggingTransaction": 161 if not txn_name: 162 txn_name = self.default_txn_name 163 164 return LoggingTransaction( 165 self.conn.cursor(), 166 name=txn_name, 167 database_engine=self.engine, 168 after_callbacks=after_callbacks, 169 exception_callbacks=exception_callbacks, 170 ) 171 172 def close(self) -> None: 173 self.conn.close() 174 175 def commit(self) -> None: 176 self.conn.commit() 177 178 def rollback(self) -> None: 179 self.conn.rollback() 180 181 def __enter__(self) -> "LoggingDatabaseConnection": 182 self.conn.__enter__() 183 return self 184 185 def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]: 186 return self.conn.__exit__(exc_type, exc_value, traceback) 187 188 # Proxy through any unknown lookups to the DB conn class. 189 def __getattr__(self, name): 190 return getattr(self.conn, name) 191 192 193# The type of entry which goes on our after_callbacks and exception_callbacks lists. 194_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]] 195 196 197R = TypeVar("R") 198 199 200class LoggingTransaction: 201 """An object that almost-transparently proxies for the 'txn' object 202 passed to the constructor. Adds logging and metrics to the .execute() 203 method. 204 205 Args: 206 txn: The database transaction object to wrap. 207 name: The name of this transactions for logging. 208 database_engine 209 after_callbacks: A list that callbacks will be appended to 210 that have been added by `call_after` which should be run on 211 successful completion of the transaction. None indicates that no 212 callbacks should be allowed to be scheduled to run. 213 exception_callbacks: A list that callbacks will be appended 214 to that have been added by `call_on_exception` which should be run 215 if transaction ends with an error. None indicates that no callbacks 216 should be allowed to be scheduled to run. 217 """ 218 219 __slots__ = [ 220 "txn", 221 "name", 222 "database_engine", 223 "after_callbacks", 224 "exception_callbacks", 225 ] 226 227 def __init__( 228 self, 229 txn: Cursor, 230 name: str, 231 database_engine: BaseDatabaseEngine, 232 after_callbacks: Optional[List[_CallbackListEntry]] = None, 233 exception_callbacks: Optional[List[_CallbackListEntry]] = None, 234 ): 235 self.txn = txn 236 self.name = name 237 self.database_engine = database_engine 238 self.after_callbacks = after_callbacks 239 self.exception_callbacks = exception_callbacks 240 241 def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any): 242 """Call the given callback on the main twisted thread after the 243 transaction has finished. Used to invalidate the caches on the 244 correct thread. 245 """ 246 # if self.after_callbacks is None, that means that whatever constructed the 247 # LoggingTransaction isn't expecting there to be any callbacks; assert that 248 # is not the case. 249 assert self.after_callbacks is not None 250 self.after_callbacks.append((callback, args, kwargs)) 251 252 def call_on_exception( 253 self, callback: Callable[..., object], *args: Any, **kwargs: Any 254 ): 255 # if self.exception_callbacks is None, that means that whatever constructed the 256 # LoggingTransaction isn't expecting there to be any callbacks; assert that 257 # is not the case. 258 assert self.exception_callbacks is not None 259 self.exception_callbacks.append((callback, args, kwargs)) 260 261 def fetchone(self) -> Optional[Tuple]: 262 return self.txn.fetchone() 263 264 def fetchmany(self, size: Optional[int] = None) -> List[Tuple]: 265 return self.txn.fetchmany(size=size) 266 267 def fetchall(self) -> List[Tuple]: 268 return self.txn.fetchall() 269 270 def __iter__(self) -> Iterator[Tuple]: 271 return self.txn.__iter__() 272 273 @property 274 def rowcount(self) -> int: 275 return self.txn.rowcount 276 277 @property 278 def description(self) -> Any: 279 return self.txn.description 280 281 def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None: 282 """Similar to `executemany`, except `txn.rowcount` will not be correct 283 afterwards. 284 285 More efficient than `executemany` on PostgreSQL 286 """ 287 288 if isinstance(self.database_engine, PostgresEngine): 289 from psycopg2.extras import execute_batch # type: ignore 290 291 self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) 292 else: 293 self.executemany(sql, args) 294 295 def execute_values(self, sql: str, *args: Any, fetch: bool = True) -> List[Tuple]: 296 """Corresponds to psycopg2.extras.execute_values. Only available when 297 using postgres. 298 299 The `fetch` parameter must be set to False if the query does not return 300 rows (e.g. INSERTs). 301 """ 302 assert isinstance(self.database_engine, PostgresEngine) 303 from psycopg2.extras import execute_values # type: ignore 304 305 return self._do_execute( 306 lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args 307 ) 308 309 def execute(self, sql: str, *args: Any) -> None: 310 self._do_execute(self.txn.execute, sql, *args) 311 312 def executemany(self, sql: str, *args: Any) -> None: 313 self._do_execute(self.txn.executemany, sql, *args) 314 315 def _make_sql_one_line(self, sql: str) -> str: 316 "Strip newlines out of SQL so that the loggers in the DB are on one line" 317 return " ".join(line.strip() for line in sql.splitlines() if line.strip()) 318 319 def _do_execute(self, func: Callable[..., R], sql: str, *args: Any) -> R: 320 sql = self._make_sql_one_line(sql) 321 322 # TODO(paul): Maybe use 'info' and 'debug' for values? 323 sql_logger.debug("[SQL] {%s} %s", self.name, sql) 324 325 sql = self.database_engine.convert_param_style(sql) 326 if args: 327 try: 328 sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) 329 except Exception: 330 # Don't let logging failures stop SQL from working 331 pass 332 333 start = time.time() 334 335 try: 336 with opentracing.start_active_span( 337 "db.query", 338 tags={ 339 opentracing.tags.DATABASE_TYPE: "sql", 340 opentracing.tags.DATABASE_STATEMENT: sql, 341 }, 342 ): 343 return func(sql, *args) 344 except Exception as e: 345 sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e) 346 raise 347 finally: 348 secs = time.time() - start 349 sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) 350 sql_query_timer.labels(sql.split()[0]).observe(secs) 351 352 def close(self) -> None: 353 self.txn.close() 354 355 def __enter__(self) -> "LoggingTransaction": 356 return self 357 358 def __exit__(self, exc_type, exc_value, traceback): 359 self.close() 360 361 362class PerformanceCounters: 363 def __init__(self): 364 self.current_counters = {} 365 self.previous_counters = {} 366 367 def update(self, key: str, duration_secs: float) -> None: 368 count, cum_time = self.current_counters.get(key, (0, 0)) 369 count += 1 370 cum_time += duration_secs 371 self.current_counters[key] = (count, cum_time) 372 373 def interval(self, interval_duration_secs: float, limit: int = 3) -> str: 374 counters = [] 375 for name, (count, cum_time) in self.current_counters.items(): 376 prev_count, prev_time = self.previous_counters.get(name, (0, 0)) 377 counters.append( 378 ( 379 (cum_time - prev_time) / interval_duration_secs, 380 count - prev_count, 381 name, 382 ) 383 ) 384 385 self.previous_counters = dict(self.current_counters) 386 387 counters.sort(reverse=True) 388 389 top_n_counters = ", ".join( 390 "%s(%d): %.3f%%" % (name, count, 100 * ratio) 391 for ratio, count, name in counters[:limit] 392 ) 393 394 return top_n_counters 395 396 397class DatabasePool: 398 """Wraps a single physical database and connection pool. 399 400 A single database may be used by multiple data stores. 401 """ 402 403 _TXN_ID = 0 404 405 def __init__( 406 self, 407 hs: "HomeServer", 408 database_config: DatabaseConnectionConfig, 409 engine: BaseDatabaseEngine, 410 ): 411 self.hs = hs 412 self._clock = hs.get_clock() 413 self._txn_limit = database_config.config.get("txn_limit", 0) 414 self._database_config = database_config 415 self._db_pool = make_pool(hs.get_reactor(), database_config, engine) 416 417 self.updates = BackgroundUpdater(hs, self) 418 419 self._previous_txn_total_time = 0.0 420 self._current_txn_total_time = 0.0 421 self._previous_loop_ts = 0.0 422 423 # Transaction counter: key is the twisted thread id, value is the current count 424 self._txn_counters: Dict[int, int] = defaultdict(int) 425 426 # TODO(paul): These can eventually be removed once the metrics code 427 # is running in mainline, and we have some nice monitoring frontends 428 # to watch it 429 self._txn_perf_counters = PerformanceCounters() 430 431 self.engine = engine 432 433 # A set of tables that are not safe to use native upserts in. 434 self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) 435 436 # We add the user_directory_search table to the blacklist on SQLite 437 # because the existing search table does not have an index, making it 438 # unsafe to use native upserts. 439 if isinstance(self.engine, Sqlite3Engine): 440 self._unsafe_to_upsert_tables.add("user_directory_search") 441 442 if self.engine.can_native_upsert: 443 # Check ASAP (and then later, every 1s) to see if we have finished 444 # background updates of tables that aren't safe to update. 445 self._clock.call_later( 446 0.0, 447 run_as_background_process, 448 "upsert_safety_check", 449 self._check_safe_to_upsert, 450 ) 451 452 def name(self) -> str: 453 "Return the name of this database" 454 return self._database_config.name 455 456 def is_running(self) -> bool: 457 """Is the database pool currently running""" 458 return self._db_pool.running 459 460 async def _check_safe_to_upsert(self) -> None: 461 """ 462 Is it safe to use native UPSERT? 463 464 If there are background updates, we will need to wait, as they may be 465 the addition of indexes that set the UNIQUE constraint that we require. 466 467 If the background updates have not completed, wait 15 sec and check again. 468 """ 469 updates = await self.simple_select_list( 470 "background_updates", 471 keyvalues=None, 472 retcols=["update_name"], 473 desc="check_background_updates", 474 ) 475 updates = [x["update_name"] for x in updates] 476 477 for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items(): 478 if update_name not in updates: 479 logger.debug("Now safe to upsert in %s", table) 480 self._unsafe_to_upsert_tables.discard(table) 481 482 # If there's any updates still running, reschedule to run. 483 if updates: 484 self._clock.call_later( 485 15.0, 486 run_as_background_process, 487 "upsert_safety_check", 488 self._check_safe_to_upsert, 489 ) 490 491 def start_profiling(self) -> None: 492 self._previous_loop_ts = monotonic_time() 493 494 def loop(): 495 curr = self._current_txn_total_time 496 prev = self._previous_txn_total_time 497 self._previous_txn_total_time = curr 498 499 time_now = monotonic_time() 500 time_then = self._previous_loop_ts 501 self._previous_loop_ts = time_now 502 503 duration = time_now - time_then 504 ratio = (curr - prev) / duration 505 506 top_three_counters = self._txn_perf_counters.interval(duration, limit=3) 507 508 perf_logger.debug( 509 "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters 510 ) 511 512 self._clock.looping_call(loop, 10000) 513 514 def new_transaction( 515 self, 516 conn: LoggingDatabaseConnection, 517 desc: str, 518 after_callbacks: List[_CallbackListEntry], 519 exception_callbacks: List[_CallbackListEntry], 520 func: Callable[..., R], 521 *args: Any, 522 **kwargs: Any, 523 ) -> R: 524 """Start a new database transaction with the given connection. 525 526 Note: The given func may be called multiple times under certain 527 failure modes. This is normally fine when in a standard transaction, 528 but care must be taken if the connection is in `autocommit` mode that 529 the function will correctly handle being aborted and retried half way 530 through its execution. 531 532 Similarly, the arguments to `func` (`args`, `kwargs`) should not be generators, 533 since they could be evaluated multiple times (which would produce an empty 534 result on the second or subsequent evaluation). Likewise, the closure of `func` 535 must not reference any generators. This method attempts to detect such usage 536 and will log an error. 537 538 Args: 539 conn 540 desc 541 after_callbacks 542 exception_callbacks 543 func 544 *args 545 **kwargs 546 """ 547 548 # Robustness check: ensure that none of the arguments are generators, since that 549 # will fail if we have to repeat the transaction. 550 # For now, we just log an error, and hope that it works on the first attempt. 551 # TODO: raise an exception. 552 for i, arg in enumerate(args): 553 if inspect.isgenerator(arg): 554 logger.error( 555 "Programming error: generator passed to new_transaction as " 556 "argument %i to function %s", 557 i, 558 func, 559 ) 560 for name, val in kwargs.items(): 561 if inspect.isgenerator(val): 562 logger.error( 563 "Programming error: generator passed to new_transaction as " 564 "argument %s to function %s", 565 name, 566 func, 567 ) 568 # also check variables referenced in func's closure 569 if inspect.isfunction(func): 570 f = cast(types.FunctionType, func) 571 if f.__closure__: 572 for i, cell in enumerate(f.__closure__): 573 if inspect.isgenerator(cell.cell_contents): 574 logger.error( 575 "Programming error: function %s references generator %s " 576 "via its closure", 577 f, 578 f.__code__.co_freevars[i], 579 ) 580 581 start = monotonic_time() 582 txn_id = self._TXN_ID 583 584 # We don't really need these to be unique, so lets stop it from 585 # growing really large. 586 self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID) 587 588 name = "%s-%x" % (desc, txn_id) 589 590 transaction_logger.debug("[TXN START] {%s}", name) 591 592 try: 593 i = 0 594 N = 5 595 while True: 596 cursor = conn.cursor( 597 txn_name=name, 598 after_callbacks=after_callbacks, 599 exception_callbacks=exception_callbacks, 600 ) 601 try: 602 with opentracing.start_active_span( 603 "db.txn", 604 tags={ 605 opentracing.SynapseTags.DB_TXN_DESC: desc, 606 opentracing.SynapseTags.DB_TXN_ID: name, 607 }, 608 ): 609 r = func(cursor, *args, **kwargs) 610 opentracing.log_kv({"message": "commit"}) 611 conn.commit() 612 return r 613 except self.engine.module.OperationalError as e: 614 # This can happen if the database disappears mid 615 # transaction. 616 transaction_logger.warning( 617 "[TXN OPERROR] {%s} %s %d/%d", 618 name, 619 e, 620 i, 621 N, 622 ) 623 if i < N: 624 i += 1 625 try: 626 with opentracing.start_active_span("db.rollback"): 627 conn.rollback() 628 except self.engine.module.Error as e1: 629 transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1) 630 continue 631 raise 632 except self.engine.module.DatabaseError as e: 633 if self.engine.is_deadlock(e): 634 transaction_logger.warning( 635 "[TXN DEADLOCK] {%s} %d/%d", name, i, N 636 ) 637 if i < N: 638 i += 1 639 try: 640 with opentracing.start_active_span("db.rollback"): 641 conn.rollback() 642 except self.engine.module.Error as e1: 643 transaction_logger.warning( 644 "[TXN EROLL] {%s} %s", 645 name, 646 e1, 647 ) 648 continue 649 raise 650 finally: 651 # we're either about to retry with a new cursor, or we're about to 652 # release the connection. Once we release the connection, it could 653 # get used for another query, which might do a conn.rollback(). 654 # 655 # In the latter case, even though that probably wouldn't affect the 656 # results of this transaction, python's sqlite will reset all 657 # statements on the connection [1], which will make our cursor 658 # invalid [2]. 659 # 660 # In any case, continuing to read rows after commit()ing seems 661 # dubious from the PoV of ACID transactional semantics 662 # (sqlite explicitly says that once you commit, you may see rows 663 # from subsequent updates.) 664 # 665 # In psycopg2, cursors are essentially a client-side fabrication - 666 # all the data is transferred to the client side when the statement 667 # finishes executing - so in theory we could go on streaming results 668 # from the cursor, but attempting to do so would make us 669 # incompatible with sqlite, so let's make sure we're not doing that 670 # by closing the cursor. 671 # 672 # (*named* cursors in psycopg2 are different and are proper server- 673 # side things, but (a) we don't use them and (b) they are implicitly 674 # closed by ending the transaction anyway.) 675 # 676 # In short, if we haven't finished with the cursor yet, that's a 677 # problem waiting to bite us. 678 # 679 # TL;DR: we're done with the cursor, so we can close it. 680 # 681 # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465 682 # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 683 cursor.close() 684 except Exception as e: 685 transaction_logger.debug("[TXN FAIL] {%s} %s", name, e) 686 raise 687 finally: 688 end = monotonic_time() 689 duration = end - start 690 691 current_context().add_database_transaction(duration) 692 693 transaction_logger.debug("[TXN END] {%s} %f sec", name, duration) 694 695 self._current_txn_total_time += duration 696 self._txn_perf_counters.update(desc, duration) 697 sql_txn_timer.labels(desc).observe(duration) 698 699 async def runInteraction( 700 self, 701 desc: str, 702 func: Callable[..., R], 703 *args: Any, 704 db_autocommit: bool = False, 705 **kwargs: Any, 706 ) -> R: 707 """Starts a transaction on the database and runs a given function 708 709 Arguments: 710 desc: description of the transaction, for logging and metrics 711 func: callback function, which will be called with a 712 database transaction (twisted.enterprise.adbapi.Transaction) as 713 its first argument, followed by `args` and `kwargs`. 714 715 db_autocommit: Whether to run the function in "autocommit" mode, 716 i.e. outside of a transaction. This is useful for transactions 717 that are only a single query. 718 719 Currently, this is only implemented for Postgres. SQLite will still 720 run the function inside a transaction. 721 722 WARNING: This means that if func fails half way through then 723 the changes will *not* be rolled back. `func` may also get 724 called multiple times if the transaction is retried, so must 725 correctly handle that case. 726 727 args: positional args to pass to `func` 728 kwargs: named args to pass to `func` 729 730 Returns: 731 The result of func 732 """ 733 after_callbacks: List[_CallbackListEntry] = [] 734 exception_callbacks: List[_CallbackListEntry] = [] 735 736 if not current_context(): 737 logger.warning("Starting db txn '%s' from sentinel context", desc) 738 739 try: 740 with opentracing.start_active_span(f"db.{desc}"): 741 result = await self.runWithConnection( 742 self.new_transaction, 743 desc, 744 after_callbacks, 745 exception_callbacks, 746 func, 747 *args, 748 db_autocommit=db_autocommit, 749 **kwargs, 750 ) 751 752 for after_callback, after_args, after_kwargs in after_callbacks: 753 after_callback(*after_args, **after_kwargs) 754 except Exception: 755 for after_callback, after_args, after_kwargs in exception_callbacks: 756 after_callback(*after_args, **after_kwargs) 757 raise 758 759 return cast(R, result) 760 761 async def runWithConnection( 762 self, 763 func: Callable[..., R], 764 *args: Any, 765 db_autocommit: bool = False, 766 **kwargs: Any, 767 ) -> R: 768 """Wraps the .runWithConnection() method on the underlying db_pool. 769 770 Arguments: 771 func: callback function, which will be called with a 772 database connection (twisted.enterprise.adbapi.Connection) as 773 its first argument, followed by `args` and `kwargs`. 774 args: positional args to pass to `func` 775 db_autocommit: Whether to run the function in "autocommit" mode, 776 i.e. outside of a transaction. This is useful for transaction 777 that are only a single query. Currently only affects postgres. 778 kwargs: named args to pass to `func` 779 780 Returns: 781 The result of func 782 """ 783 curr_context = current_context() 784 if not curr_context: 785 logger.warning( 786 "Starting db connection from sentinel context: metrics will be lost" 787 ) 788 parent_context = None 789 else: 790 assert isinstance(curr_context, LoggingContext) 791 parent_context = curr_context 792 793 start_time = monotonic_time() 794 795 def inner_func(conn, *args, **kwargs): 796 # We shouldn't be in a transaction. If we are then something 797 # somewhere hasn't committed after doing work. (This is likely only 798 # possible during startup, as `run*` will ensure changes are 799 # committed/rolled back before putting the connection back in the 800 # pool). 801 assert not self.engine.in_transaction(conn) 802 803 with LoggingContext( 804 str(curr_context), parent_context=parent_context 805 ) as context: 806 with opentracing.start_active_span( 807 operation_name="db.connection", 808 ): 809 sched_duration_sec = monotonic_time() - start_time 810 sql_scheduling_timer.observe(sched_duration_sec) 811 context.add_database_scheduled(sched_duration_sec) 812 813 if self._txn_limit > 0: 814 tid = self._db_pool.threadID() 815 self._txn_counters[tid] += 1 816 817 if self._txn_counters[tid] > self._txn_limit: 818 logger.debug( 819 "Reconnecting database connection over transaction limit" 820 ) 821 conn.reconnect() 822 opentracing.log_kv( 823 {"message": "reconnected due to txn limit"} 824 ) 825 self._txn_counters[tid] = 1 826 827 if self.engine.is_connection_closed(conn): 828 logger.debug("Reconnecting closed database connection") 829 conn.reconnect() 830 opentracing.log_kv({"message": "reconnected"}) 831 if self._txn_limit > 0: 832 self._txn_counters[tid] = 1 833 834 try: 835 if db_autocommit: 836 self.engine.attempt_to_set_autocommit(conn, True) 837 838 db_conn = LoggingDatabaseConnection( 839 conn, self.engine, "runWithConnection" 840 ) 841 return func(db_conn, *args, **kwargs) 842 finally: 843 if db_autocommit: 844 self.engine.attempt_to_set_autocommit(conn, False) 845 846 return await make_deferred_yieldable( 847 self._db_pool.runWithConnection(inner_func, *args, **kwargs) 848 ) 849 850 @staticmethod 851 def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]: 852 """Converts a SQL cursor into an list of dicts. 853 854 Args: 855 cursor: The DBAPI cursor which has executed a query. 856 Returns: 857 A list of dicts where the key is the column header. 858 """ 859 assert cursor.description is not None, "cursor.description was None" 860 col_headers = [intern(str(column[0])) for column in cursor.description] 861 results = [dict(zip(col_headers, row)) for row in cursor] 862 return results 863 864 @overload 865 async def execute( 866 self, desc: str, decoder: Literal[None], query: str, *args: Any 867 ) -> List[Tuple[Any, ...]]: 868 ... 869 870 @overload 871 async def execute( 872 self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any 873 ) -> R: 874 ... 875 876 async def execute( 877 self, 878 desc: str, 879 decoder: Optional[Callable[[Cursor], R]], 880 query: str, 881 *args: Any, 882 ) -> R: 883 """Runs a single query for a result set. 884 885 Args: 886 desc: description of the transaction, for logging and metrics 887 decoder - The function which can resolve the cursor results to 888 something meaningful. 889 query - The query string to execute 890 *args - Query args. 891 Returns: 892 The result of decoder(results) 893 """ 894 895 def interaction(txn): 896 txn.execute(query, args) 897 if decoder: 898 return decoder(txn) 899 else: 900 return txn.fetchall() 901 902 return await self.runInteraction(desc, interaction) 903 904 # "Simple" SQL API methods that operate on a single table with no JOINs, 905 # no complex WHERE clauses, just a dict of values for columns. 906 907 async def simple_insert( 908 self, 909 table: str, 910 values: Dict[str, Any], 911 desc: str = "simple_insert", 912 ) -> None: 913 """Executes an INSERT query on the named table. 914 915 Args: 916 table: string giving the table name 917 values: dict of new column names and values for them 918 desc: description of the transaction, for logging and metrics 919 """ 920 await self.runInteraction(desc, self.simple_insert_txn, table, values) 921 922 @staticmethod 923 def simple_insert_txn( 924 txn: LoggingTransaction, table: str, values: Dict[str, Any] 925 ) -> None: 926 keys, vals = zip(*values.items()) 927 928 sql = "INSERT INTO %s (%s) VALUES(%s)" % ( 929 table, 930 ", ".join(k for k in keys), 931 ", ".join("?" for _ in keys), 932 ) 933 934 txn.execute(sql, vals) 935 936 async def simple_insert_many( 937 self, table: str, values: List[Dict[str, Any]], desc: str 938 ) -> None: 939 """Executes an INSERT query on the named table. 940 941 The input is given as a list of dicts, with one dict per row. 942 Generally simple_insert_many_values should be preferred for new code. 943 944 Args: 945 table: string giving the table name 946 values: dict of new column names and values for them 947 desc: description of the transaction, for logging and metrics 948 """ 949 await self.runInteraction(desc, self.simple_insert_many_txn, table, values) 950 951 @staticmethod 952 def simple_insert_many_txn( 953 txn: LoggingTransaction, table: str, values: List[Dict[str, Any]] 954 ) -> None: 955 """Executes an INSERT query on the named table. 956 957 The input is given as a list of dicts, with one dict per row. 958 Generally simple_insert_many_values_txn should be preferred for new code. 959 960 Args: 961 txn: The transaction to use. 962 table: string giving the table name 963 values: dict of new column names and values for them 964 """ 965 if not values: 966 return 967 968 # This is a *slight* abomination to get a list of tuples of key names 969 # and a list of tuples of value names. 970 # 971 # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}] 972 # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)] 973 # 974 # The sort is to ensure that we don't rely on dictionary iteration 975 # order. 976 keys, vals = zip( 977 *(zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i) 978 ) 979 980 for k in keys: 981 if k != keys[0]: 982 raise RuntimeError("All items must have the same keys") 983 984 return DatabasePool.simple_insert_many_values_txn(txn, table, keys[0], vals) 985 986 async def simple_insert_many_values( 987 self, 988 table: str, 989 keys: Collection[str], 990 values: Collection[Collection[Any]], 991 desc: str, 992 ) -> None: 993 """Executes an INSERT query on the named table. 994 995 The input is given as a list of rows, where each row is a list of values. 996 (Actually any iterable is fine.) 997 998 Args: 999 table: string giving the table name 1000 keys: list of column names 1001 values: for each row, a list of values in the same order as `keys` 1002 desc: description of the transaction, for logging and metrics 1003 """ 1004 await self.runInteraction( 1005 desc, self.simple_insert_many_values_txn, table, keys, values 1006 ) 1007 1008 @staticmethod 1009 def simple_insert_many_values_txn( 1010 txn: LoggingTransaction, 1011 table: str, 1012 keys: Collection[str], 1013 values: Iterable[Iterable[Any]], 1014 ) -> None: 1015 """Executes an INSERT query on the named table. 1016 1017 The input is given as a list of rows, where each row is a list of values. 1018 (Actually any iterable is fine.) 1019 1020 Args: 1021 txn: The transaction to use. 1022 table: string giving the table name 1023 keys: list of column names 1024 values: for each row, a list of values in the same order as `keys` 1025 """ 1026 1027 if isinstance(txn.database_engine, PostgresEngine): 1028 # We use `execute_values` as it can be a lot faster than `execute_batch`, 1029 # but it's only available on postgres. 1030 sql = "INSERT INTO %s (%s) VALUES ?" % ( 1031 table, 1032 ", ".join(k for k in keys), 1033 ) 1034 1035 txn.execute_values(sql, values, fetch=False) 1036 else: 1037 sql = "INSERT INTO %s (%s) VALUES(%s)" % ( 1038 table, 1039 ", ".join(k for k in keys), 1040 ", ".join("?" for _ in keys), 1041 ) 1042 1043 txn.execute_batch(sql, values) 1044 1045 async def simple_upsert( 1046 self, 1047 table: str, 1048 keyvalues: Dict[str, Any], 1049 values: Dict[str, Any], 1050 insertion_values: Optional[Dict[str, Any]] = None, 1051 desc: str = "simple_upsert", 1052 lock: bool = True, 1053 ) -> bool: 1054 """ 1055 1056 `lock` should generally be set to True (the default), but can be set 1057 to False if either of the following are true: 1058 1. there is a UNIQUE INDEX on the key columns. In this case a conflict 1059 will cause an IntegrityError in which case this function will retry 1060 the update. 1061 2. we somehow know that we are the only thread which will be updating 1062 this table. 1063 As an additional note, this parameter only matters for old SQLite versions 1064 because we will use native upserts otherwise. 1065 1066 Args: 1067 table: The table to upsert into 1068 keyvalues: The unique key columns and their new values 1069 values: The nonunique columns and their new values 1070 insertion_values: additional key/values to use only when inserting 1071 desc: description of the transaction, for logging and metrics 1072 lock: True to lock the table when doing the upsert. 1073 Returns: 1074 Returns True if a row was inserted or updated (i.e. if `values` is 1075 not empty then this always returns True) 1076 """ 1077 insertion_values = insertion_values or {} 1078 1079 attempts = 0 1080 while True: 1081 try: 1082 # We can autocommit if we are going to use native upserts 1083 autocommit = ( 1084 self.engine.can_native_upsert 1085 and table not in self._unsafe_to_upsert_tables 1086 ) 1087 1088 return await self.runInteraction( 1089 desc, 1090 self.simple_upsert_txn, 1091 table, 1092 keyvalues, 1093 values, 1094 insertion_values, 1095 lock=lock, 1096 db_autocommit=autocommit, 1097 ) 1098 except self.engine.module.IntegrityError as e: 1099 attempts += 1 1100 if attempts >= 5: 1101 # don't retry forever, because things other than races 1102 # can cause IntegrityErrors 1103 raise 1104 1105 # presumably we raced with another transaction: let's retry. 1106 logger.warning( 1107 "IntegrityError when upserting into %s; retrying: %s", table, e 1108 ) 1109 1110 def simple_upsert_txn( 1111 self, 1112 txn: LoggingTransaction, 1113 table: str, 1114 keyvalues: Dict[str, Any], 1115 values: Dict[str, Any], 1116 insertion_values: Optional[Dict[str, Any]] = None, 1117 lock: bool = True, 1118 ) -> bool: 1119 """ 1120 Pick the UPSERT method which works best on the platform. Either the 1121 native one (Pg9.5+, recent SQLites), or fall back to an emulated method. 1122 1123 Args: 1124 txn: The transaction to use. 1125 table: The table to upsert into 1126 keyvalues: The unique key tables and their new values 1127 values: The nonunique columns and their new values 1128 insertion_values: additional key/values to use only when inserting 1129 lock: True to lock the table when doing the upsert. 1130 Returns: 1131 Returns True if a row was inserted or updated (i.e. if `values` is 1132 not empty then this always returns True) 1133 """ 1134 insertion_values = insertion_values or {} 1135 1136 if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: 1137 return self.simple_upsert_txn_native_upsert( 1138 txn, table, keyvalues, values, insertion_values=insertion_values 1139 ) 1140 else: 1141 return self.simple_upsert_txn_emulated( 1142 txn, 1143 table, 1144 keyvalues, 1145 values, 1146 insertion_values=insertion_values, 1147 lock=lock, 1148 ) 1149 1150 def simple_upsert_txn_emulated( 1151 self, 1152 txn: LoggingTransaction, 1153 table: str, 1154 keyvalues: Dict[str, Any], 1155 values: Dict[str, Any], 1156 insertion_values: Optional[Dict[str, Any]] = None, 1157 lock: bool = True, 1158 ) -> bool: 1159 """ 1160 Args: 1161 table: The table to upsert into 1162 keyvalues: The unique key tables and their new values 1163 values: The nonunique columns and their new values 1164 insertion_values: additional key/values to use only when inserting 1165 lock: True to lock the table when doing the upsert. 1166 Returns: 1167 Returns True if a row was inserted or updated (i.e. if `values` is 1168 not empty then this always returns True) 1169 """ 1170 insertion_values = insertion_values or {} 1171 1172 # We need to lock the table :(, unless we're *really* careful 1173 if lock: 1174 self.engine.lock_table(txn, table) 1175 1176 def _getwhere(key): 1177 # If the value we're passing in is None (aka NULL), we need to use 1178 # IS, not =, as NULL = NULL equals NULL (False). 1179 if keyvalues[key] is None: 1180 return "%s IS ?" % (key,) 1181 else: 1182 return "%s = ?" % (key,) 1183 1184 if not values: 1185 # If `values` is empty, then all of the values we care about are in 1186 # the unique key, so there is nothing to UPDATE. We can just do a 1187 # SELECT instead to see if it exists. 1188 sql = "SELECT 1 FROM %s WHERE %s" % ( 1189 table, 1190 " AND ".join(_getwhere(k) for k in keyvalues), 1191 ) 1192 sqlargs = list(keyvalues.values()) 1193 txn.execute(sql, sqlargs) 1194 if txn.fetchall(): 1195 # We have an existing record. 1196 return False 1197 else: 1198 # First try to update. 1199 sql = "UPDATE %s SET %s WHERE %s" % ( 1200 table, 1201 ", ".join("%s = ?" % (k,) for k in values), 1202 " AND ".join(_getwhere(k) for k in keyvalues), 1203 ) 1204 sqlargs = list(values.values()) + list(keyvalues.values()) 1205 1206 txn.execute(sql, sqlargs) 1207 if txn.rowcount > 0: 1208 return True 1209 1210 # We didn't find any existing rows, so insert a new one 1211 allvalues: Dict[str, Any] = {} 1212 allvalues.update(keyvalues) 1213 allvalues.update(values) 1214 allvalues.update(insertion_values) 1215 1216 sql = "INSERT INTO %s (%s) VALUES (%s)" % ( 1217 table, 1218 ", ".join(k for k in allvalues), 1219 ", ".join("?" for _ in allvalues), 1220 ) 1221 txn.execute(sql, list(allvalues.values())) 1222 # successfully inserted 1223 return True 1224 1225 def simple_upsert_txn_native_upsert( 1226 self, 1227 txn: LoggingTransaction, 1228 table: str, 1229 keyvalues: Dict[str, Any], 1230 values: Dict[str, Any], 1231 insertion_values: Optional[Dict[str, Any]] = None, 1232 ) -> bool: 1233 """ 1234 Use the native UPSERT functionality in PostgreSQL. 1235 1236 Args: 1237 table: The table to upsert into 1238 keyvalues: The unique key tables and their new values 1239 values: The nonunique columns and their new values 1240 insertion_values: additional key/values to use only when inserting 1241 1242 Returns: 1243 Returns True if a row was inserted or updated (i.e. if `values` is 1244 not empty then this always returns True) 1245 """ 1246 allvalues: Dict[str, Any] = {} 1247 allvalues.update(keyvalues) 1248 allvalues.update(insertion_values or {}) 1249 1250 if not values: 1251 latter = "NOTHING" 1252 else: 1253 allvalues.update(values) 1254 latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values) 1255 1256 sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % ( 1257 table, 1258 ", ".join(k for k in allvalues), 1259 ", ".join("?" for _ in allvalues), 1260 ", ".join(k for k in keyvalues), 1261 latter, 1262 ) 1263 txn.execute(sql, list(allvalues.values())) 1264 1265 return bool(txn.rowcount) 1266 1267 async def simple_upsert_many( 1268 self, 1269 table: str, 1270 key_names: Collection[str], 1271 key_values: Collection[Collection[Any]], 1272 value_names: Collection[str], 1273 value_values: Collection[Collection[Any]], 1274 desc: str, 1275 ) -> None: 1276 """ 1277 Upsert, many times. 1278 1279 Args: 1280 table: The table to upsert into 1281 key_names: The key column names. 1282 key_values: A list of each row's key column values. 1283 value_names: The value column names 1284 value_values: A list of each row's value column values. 1285 Ignored if value_names is empty. 1286 """ 1287 1288 # We can autocommit if we are going to use native upserts 1289 autocommit = ( 1290 self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables 1291 ) 1292 1293 return await self.runInteraction( 1294 desc, 1295 self.simple_upsert_many_txn, 1296 table, 1297 key_names, 1298 key_values, 1299 value_names, 1300 value_values, 1301 db_autocommit=autocommit, 1302 ) 1303 1304 def simple_upsert_many_txn( 1305 self, 1306 txn: LoggingTransaction, 1307 table: str, 1308 key_names: Collection[str], 1309 key_values: Collection[Iterable[Any]], 1310 value_names: Collection[str], 1311 value_values: Iterable[Iterable[Any]], 1312 ) -> None: 1313 """ 1314 Upsert, many times. 1315 1316 Args: 1317 table: The table to upsert into 1318 key_names: The key column names. 1319 key_values: A list of each row's key column values. 1320 value_names: The value column names 1321 value_values: A list of each row's value column values. 1322 Ignored if value_names is empty. 1323 """ 1324 if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables: 1325 return self.simple_upsert_many_txn_native_upsert( 1326 txn, table, key_names, key_values, value_names, value_values 1327 ) 1328 else: 1329 return self.simple_upsert_many_txn_emulated( 1330 txn, table, key_names, key_values, value_names, value_values 1331 ) 1332 1333 def simple_upsert_many_txn_emulated( 1334 self, 1335 txn: LoggingTransaction, 1336 table: str, 1337 key_names: Iterable[str], 1338 key_values: Collection[Iterable[Any]], 1339 value_names: Collection[str], 1340 value_values: Iterable[Iterable[Any]], 1341 ) -> None: 1342 """ 1343 Upsert, many times, but without native UPSERT support or batching. 1344 1345 Args: 1346 table: The table to upsert into 1347 key_names: The key column names. 1348 key_values: A list of each row's key column values. 1349 value_names: The value column names 1350 value_values: A list of each row's value column values. 1351 Ignored if value_names is empty. 1352 """ 1353 # No value columns, therefore make a blank list so that the following 1354 # zip() works correctly. 1355 if not value_names: 1356 value_values = [() for x in range(len(key_values))] 1357 1358 for keyv, valv in zip(key_values, value_values): 1359 _keys = {x: y for x, y in zip(key_names, keyv)} 1360 _vals = {x: y for x, y in zip(value_names, valv)} 1361 1362 self.simple_upsert_txn_emulated(txn, table, _keys, _vals) 1363 1364 def simple_upsert_many_txn_native_upsert( 1365 self, 1366 txn: LoggingTransaction, 1367 table: str, 1368 key_names: Collection[str], 1369 key_values: Collection[Iterable[Any]], 1370 value_names: Collection[str], 1371 value_values: Iterable[Iterable[Any]], 1372 ) -> None: 1373 """ 1374 Upsert, many times, using batching where possible. 1375 1376 Args: 1377 table: The table to upsert into 1378 key_names: The key column names. 1379 key_values: A list of each row's key column values. 1380 value_names: The value column names 1381 value_values: A list of each row's value column values. 1382 Ignored if value_names is empty. 1383 """ 1384 allnames: List[str] = [] 1385 allnames.extend(key_names) 1386 allnames.extend(value_names) 1387 1388 if not value_names: 1389 # No value columns, therefore make a blank list so that the 1390 # following zip() works correctly. 1391 latter = "NOTHING" 1392 value_values = [() for x in range(len(key_values))] 1393 else: 1394 latter = "UPDATE SET " + ", ".join( 1395 k + "=EXCLUDED." + k for k in value_names 1396 ) 1397 1398 args = [] 1399 1400 for x, y in zip(key_values, value_values): 1401 args.append(tuple(x) + tuple(y)) 1402 1403 if isinstance(txn.database_engine, PostgresEngine): 1404 # We use `execute_values` as it can be a lot faster than `execute_batch`, 1405 # but it's only available on postgres. 1406 sql = "INSERT INTO %s (%s) VALUES ? ON CONFLICT (%s) DO %s" % ( 1407 table, 1408 ", ".join(k for k in allnames), 1409 ", ".join(key_names), 1410 latter, 1411 ) 1412 1413 txn.execute_values(sql, args, fetch=False) 1414 1415 else: 1416 sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( 1417 table, 1418 ", ".join(k for k in allnames), 1419 ", ".join("?" for _ in allnames), 1420 ", ".join(key_names), 1421 latter, 1422 ) 1423 1424 return txn.execute_batch(sql, args) 1425 1426 @overload 1427 async def simple_select_one( 1428 self, 1429 table: str, 1430 keyvalues: Dict[str, Any], 1431 retcols: Collection[str], 1432 allow_none: Literal[False] = False, 1433 desc: str = "simple_select_one", 1434 ) -> Dict[str, Any]: 1435 ... 1436 1437 @overload 1438 async def simple_select_one( 1439 self, 1440 table: str, 1441 keyvalues: Dict[str, Any], 1442 retcols: Collection[str], 1443 allow_none: Literal[True] = True, 1444 desc: str = "simple_select_one", 1445 ) -> Optional[Dict[str, Any]]: 1446 ... 1447 1448 async def simple_select_one( 1449 self, 1450 table: str, 1451 keyvalues: Dict[str, Any], 1452 retcols: Collection[str], 1453 allow_none: bool = False, 1454 desc: str = "simple_select_one", 1455 ) -> Optional[Dict[str, Any]]: 1456 """Executes a SELECT query on the named table, which is expected to 1457 return a single row, returning multiple columns from it. 1458 1459 Args: 1460 table: string giving the table name 1461 keyvalues: dict of column names and values to select the row with 1462 retcols: list of strings giving the names of the columns to return 1463 allow_none: If true, return None instead of failing if the SELECT 1464 statement returns no rows 1465 desc: description of the transaction, for logging and metrics 1466 """ 1467 return await self.runInteraction( 1468 desc, 1469 self.simple_select_one_txn, 1470 table, 1471 keyvalues, 1472 retcols, 1473 allow_none, 1474 db_autocommit=True, 1475 ) 1476 1477 @overload 1478 async def simple_select_one_onecol( 1479 self, 1480 table: str, 1481 keyvalues: Dict[str, Any], 1482 retcol: str, 1483 allow_none: Literal[False] = False, 1484 desc: str = "simple_select_one_onecol", 1485 ) -> Any: 1486 ... 1487 1488 @overload 1489 async def simple_select_one_onecol( 1490 self, 1491 table: str, 1492 keyvalues: Dict[str, Any], 1493 retcol: str, 1494 allow_none: Literal[True] = True, 1495 desc: str = "simple_select_one_onecol", 1496 ) -> Optional[Any]: 1497 ... 1498 1499 async def simple_select_one_onecol( 1500 self, 1501 table: str, 1502 keyvalues: Dict[str, Any], 1503 retcol: str, 1504 allow_none: bool = False, 1505 desc: str = "simple_select_one_onecol", 1506 ) -> Optional[Any]: 1507 """Executes a SELECT query on the named table, which is expected to 1508 return a single row, returning a single column from it. 1509 1510 Args: 1511 table: string giving the table name 1512 keyvalues: dict of column names and values to select the row with 1513 retcol: string giving the name of the column to return 1514 allow_none: If true, return None instead of failing if the SELECT 1515 statement returns no rows 1516 desc: description of the transaction, for logging and metrics 1517 """ 1518 return await self.runInteraction( 1519 desc, 1520 self.simple_select_one_onecol_txn, 1521 table, 1522 keyvalues, 1523 retcol, 1524 allow_none=allow_none, 1525 db_autocommit=True, 1526 ) 1527 1528 @overload 1529 @classmethod 1530 def simple_select_one_onecol_txn( 1531 cls, 1532 txn: LoggingTransaction, 1533 table: str, 1534 keyvalues: Dict[str, Any], 1535 retcol: str, 1536 allow_none: Literal[False] = False, 1537 ) -> Any: 1538 ... 1539 1540 @overload 1541 @classmethod 1542 def simple_select_one_onecol_txn( 1543 cls, 1544 txn: LoggingTransaction, 1545 table: str, 1546 keyvalues: Dict[str, Any], 1547 retcol: str, 1548 allow_none: Literal[True] = True, 1549 ) -> Optional[Any]: 1550 ... 1551 1552 @classmethod 1553 def simple_select_one_onecol_txn( 1554 cls, 1555 txn: LoggingTransaction, 1556 table: str, 1557 keyvalues: Dict[str, Any], 1558 retcol: str, 1559 allow_none: bool = False, 1560 ) -> Optional[Any]: 1561 ret = cls.simple_select_onecol_txn( 1562 txn, table=table, keyvalues=keyvalues, retcol=retcol 1563 ) 1564 1565 if ret: 1566 return ret[0] 1567 else: 1568 if allow_none: 1569 return None 1570 else: 1571 raise StoreError(404, "No row found") 1572 1573 @staticmethod 1574 def simple_select_onecol_txn( 1575 txn: LoggingTransaction, 1576 table: str, 1577 keyvalues: Dict[str, Any], 1578 retcol: str, 1579 ) -> List[Any]: 1580 sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} 1581 1582 if keyvalues: 1583 sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) 1584 txn.execute(sql, list(keyvalues.values())) 1585 else: 1586 txn.execute(sql) 1587 1588 return [r[0] for r in txn] 1589 1590 async def simple_select_onecol( 1591 self, 1592 table: str, 1593 keyvalues: Optional[Dict[str, Any]], 1594 retcol: str, 1595 desc: str = "simple_select_onecol", 1596 ) -> List[Any]: 1597 """Executes a SELECT query on the named table, which returns a list 1598 comprising of the values of the named column from the selected rows. 1599 1600 Args: 1601 table: table name 1602 keyvalues: column names and values to select the rows with 1603 retcol: column whos value we wish to retrieve. 1604 desc: description of the transaction, for logging and metrics 1605 1606 Returns: 1607 Results in a list 1608 """ 1609 return await self.runInteraction( 1610 desc, 1611 self.simple_select_onecol_txn, 1612 table, 1613 keyvalues, 1614 retcol, 1615 db_autocommit=True, 1616 ) 1617 1618 async def simple_select_list( 1619 self, 1620 table: str, 1621 keyvalues: Optional[Dict[str, Any]], 1622 retcols: Collection[str], 1623 desc: str = "simple_select_list", 1624 ) -> List[Dict[str, Any]]: 1625 """Executes a SELECT query on the named table, which may return zero or 1626 more rows, returning the result as a list of dicts. 1627 1628 Args: 1629 table: the table name 1630 keyvalues: 1631 column names and values to select the rows with, or None to not 1632 apply a WHERE clause. 1633 retcols: the names of the columns to return 1634 desc: description of the transaction, for logging and metrics 1635 1636 Returns: 1637 A list of dictionaries. 1638 """ 1639 return await self.runInteraction( 1640 desc, 1641 self.simple_select_list_txn, 1642 table, 1643 keyvalues, 1644 retcols, 1645 db_autocommit=True, 1646 ) 1647 1648 @classmethod 1649 def simple_select_list_txn( 1650 cls, 1651 txn: LoggingTransaction, 1652 table: str, 1653 keyvalues: Optional[Dict[str, Any]], 1654 retcols: Iterable[str], 1655 ) -> List[Dict[str, Any]]: 1656 """Executes a SELECT query on the named table, which may return zero or 1657 more rows, returning the result as a list of dicts. 1658 1659 Args: 1660 txn: Transaction object 1661 table: the table name 1662 keyvalues: 1663 column names and values to select the rows with, or None to not 1664 apply a WHERE clause. 1665 retcols: the names of the columns to return 1666 """ 1667 if keyvalues: 1668 sql = "SELECT %s FROM %s WHERE %s" % ( 1669 ", ".join(retcols), 1670 table, 1671 " AND ".join("%s = ?" % (k,) for k in keyvalues), 1672 ) 1673 txn.execute(sql, list(keyvalues.values())) 1674 else: 1675 sql = "SELECT %s FROM %s" % (", ".join(retcols), table) 1676 txn.execute(sql) 1677 1678 return cls.cursor_to_dict(txn) 1679 1680 async def simple_select_many_batch( 1681 self, 1682 table: str, 1683 column: str, 1684 iterable: Iterable[Any], 1685 retcols: Collection[str], 1686 keyvalues: Optional[Dict[str, Any]] = None, 1687 desc: str = "simple_select_many_batch", 1688 batch_size: int = 100, 1689 ) -> List[Any]: 1690 """Executes a SELECT query on the named table, which may return zero or 1691 more rows, returning the result as a list of dicts. 1692 1693 Filters rows by whether the value of `column` is in `iterable`. 1694 1695 Args: 1696 table: string giving the table name 1697 column: column name to test for inclusion against `iterable` 1698 iterable: list 1699 retcols: list of strings giving the names of the columns to return 1700 keyvalues: dict of column names and values to select the rows with 1701 desc: description of the transaction, for logging and metrics 1702 batch_size: the number of rows for each select query 1703 """ 1704 keyvalues = keyvalues or {} 1705 1706 results: List[Dict[str, Any]] = [] 1707 1708 for chunk in batch_iter(iterable, batch_size): 1709 rows = await self.runInteraction( 1710 desc, 1711 self.simple_select_many_txn, 1712 table, 1713 column, 1714 chunk, 1715 keyvalues, 1716 retcols, 1717 db_autocommit=True, 1718 ) 1719 1720 results.extend(rows) 1721 1722 return results 1723 1724 @classmethod 1725 def simple_select_many_txn( 1726 cls, 1727 txn: LoggingTransaction, 1728 table: str, 1729 column: str, 1730 iterable: Collection[Any], 1731 keyvalues: Dict[str, Any], 1732 retcols: Iterable[str], 1733 ) -> List[Dict[str, Any]]: 1734 """Executes a SELECT query on the named table, which may return zero or 1735 more rows, returning the result as a list of dicts. 1736 1737 Filters rows by whether the value of `column` is in `iterable`. 1738 1739 Args: 1740 txn: Transaction object 1741 table: string giving the table name 1742 column: column name to test for inclusion against `iterable` 1743 iterable: list 1744 keyvalues: dict of column names and values to select the rows with 1745 retcols: list of strings giving the names of the columns to return 1746 """ 1747 if not iterable: 1748 return [] 1749 1750 clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) 1751 clauses = [clause] 1752 1753 for key, value in keyvalues.items(): 1754 clauses.append("%s = ?" % (key,)) 1755 values.append(value) 1756 1757 sql = "SELECT %s FROM %s WHERE %s" % ( 1758 ", ".join(retcols), 1759 table, 1760 " AND ".join(clauses), 1761 ) 1762 1763 txn.execute(sql, values) 1764 return cls.cursor_to_dict(txn) 1765 1766 async def simple_update( 1767 self, 1768 table: str, 1769 keyvalues: Dict[str, Any], 1770 updatevalues: Dict[str, Any], 1771 desc: str, 1772 ) -> int: 1773 return await self.runInteraction( 1774 desc, self.simple_update_txn, table, keyvalues, updatevalues 1775 ) 1776 1777 @staticmethod 1778 def simple_update_txn( 1779 txn: LoggingTransaction, 1780 table: str, 1781 keyvalues: Dict[str, Any], 1782 updatevalues: Dict[str, Any], 1783 ) -> int: 1784 if keyvalues: 1785 where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) 1786 else: 1787 where = "" 1788 1789 update_sql = "UPDATE %s SET %s %s" % ( 1790 table, 1791 ", ".join("%s = ?" % (k,) for k in updatevalues), 1792 where, 1793 ) 1794 1795 txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values())) 1796 1797 return txn.rowcount 1798 1799 async def simple_update_one( 1800 self, 1801 table: str, 1802 keyvalues: Dict[str, Any], 1803 updatevalues: Dict[str, Any], 1804 desc: str = "simple_update_one", 1805 ) -> None: 1806 """Executes an UPDATE query on the named table, setting new values for 1807 columns in a row matching the key values. 1808 1809 Args: 1810 table: string giving the table name 1811 keyvalues: dict of column names and values to select the row with 1812 updatevalues: dict giving column names and values to update 1813 desc: description of the transaction, for logging and metrics 1814 """ 1815 await self.runInteraction( 1816 desc, 1817 self.simple_update_one_txn, 1818 table, 1819 keyvalues, 1820 updatevalues, 1821 db_autocommit=True, 1822 ) 1823 1824 @classmethod 1825 def simple_update_one_txn( 1826 cls, 1827 txn: LoggingTransaction, 1828 table: str, 1829 keyvalues: Dict[str, Any], 1830 updatevalues: Dict[str, Any], 1831 ) -> None: 1832 rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues) 1833 1834 if rowcount == 0: 1835 raise StoreError(404, "No row found (%s)" % (table,)) 1836 if rowcount > 1: 1837 raise StoreError(500, "More than one row matched (%s)" % (table,)) 1838 1839 # Ideally we could use the overload decorator here to specify that the 1840 # return type is only optional if allow_none is True, but this does not work 1841 # when you call a static method from an instance. 1842 # See https://github.com/python/mypy/issues/7781 1843 @staticmethod 1844 def simple_select_one_txn( 1845 txn: LoggingTransaction, 1846 table: str, 1847 keyvalues: Dict[str, Any], 1848 retcols: Collection[str], 1849 allow_none: bool = False, 1850 ) -> Optional[Dict[str, Any]]: 1851 select_sql = "SELECT %s FROM %s WHERE %s" % ( 1852 ", ".join(retcols), 1853 table, 1854 " AND ".join("%s = ?" % (k,) for k in keyvalues), 1855 ) 1856 1857 txn.execute(select_sql, list(keyvalues.values())) 1858 row = txn.fetchone() 1859 1860 if not row: 1861 if allow_none: 1862 return None 1863 raise StoreError(404, "No row found (%s)" % (table,)) 1864 if txn.rowcount > 1: 1865 raise StoreError(500, "More than one row matched (%s)" % (table,)) 1866 1867 return dict(zip(retcols, row)) 1868 1869 async def simple_delete_one( 1870 self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one" 1871 ) -> None: 1872 """Executes a DELETE query on the named table, expecting to delete a 1873 single row. 1874 1875 Args: 1876 table: string giving the table name 1877 keyvalues: dict of column names and values to select the row with 1878 desc: description of the transaction, for logging and metrics 1879 """ 1880 await self.runInteraction( 1881 desc, 1882 self.simple_delete_one_txn, 1883 table, 1884 keyvalues, 1885 db_autocommit=True, 1886 ) 1887 1888 @staticmethod 1889 def simple_delete_one_txn( 1890 txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any] 1891 ) -> None: 1892 """Executes a DELETE query on the named table, expecting to delete a 1893 single row. 1894 1895 Args: 1896 table: string giving the table name 1897 keyvalues: dict of column names and values to select the row with 1898 """ 1899 sql = "DELETE FROM %s WHERE %s" % ( 1900 table, 1901 " AND ".join("%s = ?" % (k,) for k in keyvalues), 1902 ) 1903 1904 txn.execute(sql, list(keyvalues.values())) 1905 if txn.rowcount == 0: 1906 raise StoreError(404, "No row found (%s)" % (table,)) 1907 if txn.rowcount > 1: 1908 raise StoreError(500, "More than one row matched (%s)" % (table,)) 1909 1910 async def simple_delete( 1911 self, table: str, keyvalues: Dict[str, Any], desc: str 1912 ) -> int: 1913 """Executes a DELETE query on the named table. 1914 1915 Filters rows by the key-value pairs. 1916 1917 Args: 1918 table: string giving the table name 1919 keyvalues: dict of column names and values to select the row with 1920 desc: description of the transaction, for logging and metrics 1921 1922 Returns: 1923 The number of deleted rows. 1924 """ 1925 return await self.runInteraction( 1926 desc, self.simple_delete_txn, table, keyvalues, db_autocommit=True 1927 ) 1928 1929 @staticmethod 1930 def simple_delete_txn( 1931 txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any] 1932 ) -> int: 1933 """Executes a DELETE query on the named table. 1934 1935 Filters rows by the key-value pairs. 1936 1937 Args: 1938 table: string giving the table name 1939 keyvalues: dict of column names and values to select the row with 1940 1941 Returns: 1942 The number of deleted rows. 1943 """ 1944 sql = "DELETE FROM %s WHERE %s" % ( 1945 table, 1946 " AND ".join("%s = ?" % (k,) for k in keyvalues), 1947 ) 1948 1949 txn.execute(sql, list(keyvalues.values())) 1950 return txn.rowcount 1951 1952 async def simple_delete_many( 1953 self, 1954 table: str, 1955 column: str, 1956 iterable: Collection[Any], 1957 keyvalues: Dict[str, Any], 1958 desc: str, 1959 ) -> int: 1960 """Executes a DELETE query on the named table. 1961 1962 Filters rows by if value of `column` is in `iterable`. 1963 1964 Args: 1965 table: string giving the table name 1966 column: column name to test for inclusion against `iterable` 1967 iterable: list of values to match against `column`. NB cannot be a generator 1968 as it may be evaluated multiple times. 1969 keyvalues: dict of column names and values to select the rows with 1970 desc: description of the transaction, for logging and metrics 1971 1972 Returns: 1973 Number rows deleted 1974 """ 1975 return await self.runInteraction( 1976 desc, 1977 self.simple_delete_many_txn, 1978 table, 1979 column, 1980 iterable, 1981 keyvalues, 1982 db_autocommit=True, 1983 ) 1984 1985 @staticmethod 1986 def simple_delete_many_txn( 1987 txn: LoggingTransaction, 1988 table: str, 1989 column: str, 1990 values: Collection[Any], 1991 keyvalues: Dict[str, Any], 1992 ) -> int: 1993 """Executes a DELETE query on the named table. 1994 1995 Deletes the rows: 1996 - whose value of `column` is in `values`; AND 1997 - that match extra column-value pairs specified in `keyvalues`. 1998 1999 Args: 2000 txn: Transaction object 2001 table: string giving the table name 2002 column: column name to test for inclusion against `values` 2003 values: values of `column` which choose rows to delete 2004 keyvalues: dict of extra column names and values to select the rows 2005 with. They will be ANDed together with the main predicate. 2006 2007 Returns: 2008 Number rows deleted 2009 """ 2010 if not values: 2011 return 0 2012 2013 sql = "DELETE FROM %s" % table 2014 2015 clause, values = make_in_list_sql_clause(txn.database_engine, column, values) 2016 clauses = [clause] 2017 2018 for key, value in keyvalues.items(): 2019 clauses.append("%s = ?" % (key,)) 2020 values.append(value) 2021 2022 if clauses: 2023 sql = "%s WHERE %s" % (sql, " AND ".join(clauses)) 2024 txn.execute(sql, values) 2025 2026 return txn.rowcount 2027 2028 def get_cache_dict( 2029 self, 2030 db_conn: LoggingDatabaseConnection, 2031 table: str, 2032 entity_column: str, 2033 stream_column: str, 2034 max_value: int, 2035 limit: int = 100000, 2036 ) -> Tuple[Dict[Any, int], int]: 2037 # Fetch a mapping of room_id -> max stream position for "recent" rooms. 2038 # It doesn't really matter how many we get, the StreamChangeCache will 2039 # do the right thing to ensure it respects the max size of cache. 2040 sql = ( 2041 "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s" 2042 " WHERE %(stream)s > ? - %(limit)s" 2043 " GROUP BY %(entity)s" 2044 ) % { 2045 "table": table, 2046 "entity": entity_column, 2047 "stream": stream_column, 2048 "limit": limit, 2049 } 2050 2051 txn = db_conn.cursor(txn_name="get_cache_dict") 2052 txn.execute(sql, (int(max_value),)) 2053 2054 cache = {row[0]: int(row[1]) for row in txn} 2055 2056 txn.close() 2057 2058 if cache: 2059 min_val = min(cache.values()) 2060 else: 2061 min_val = max_value 2062 2063 return cache, min_val 2064 2065 @classmethod 2066 def simple_select_list_paginate_txn( 2067 cls, 2068 txn: LoggingTransaction, 2069 table: str, 2070 orderby: str, 2071 start: int, 2072 limit: int, 2073 retcols: Iterable[str], 2074 filters: Optional[Dict[str, Any]] = None, 2075 keyvalues: Optional[Dict[str, Any]] = None, 2076 exclude_keyvalues: Optional[Dict[str, Any]] = None, 2077 order_direction: str = "ASC", 2078 ) -> List[Dict[str, Any]]: 2079 """ 2080 Executes a SELECT query on the named table with start and limit, 2081 of row numbers, which may return zero or number of rows from start to limit, 2082 returning the result as a list of dicts. 2083 2084 Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to 2085 select attributes with exact matches. All constraints are joined together 2086 using 'AND'. 2087 2088 Args: 2089 txn: Transaction object 2090 table: the table name 2091 orderby: Column to order the results by. 2092 start: Index to begin the query at. 2093 limit: Number of results to return. 2094 retcols: the names of the columns to return 2095 filters: 2096 column names and values to filter the rows with, or None to not 2097 apply a WHERE ? LIKE ? clause. 2098 keyvalues: 2099 column names and values to select the rows with, or None to not 2100 apply a WHERE key = value clause. 2101 exclude_keyvalues: 2102 column names and values to exclude rows with, or None to not 2103 apply a WHERE key != value clause. 2104 order_direction: Whether the results should be ordered "ASC" or "DESC". 2105 2106 Returns: 2107 The result as a list of dictionaries. 2108 """ 2109 if order_direction not in ["ASC", "DESC"]: 2110 raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") 2111 2112 where_clause = "WHERE " if filters or keyvalues or exclude_keyvalues else "" 2113 arg_list: List[Any] = [] 2114 if filters: 2115 where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters) 2116 arg_list += list(filters.values()) 2117 where_clause += " AND " if filters and keyvalues else "" 2118 if keyvalues: 2119 where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues) 2120 arg_list += list(keyvalues.values()) 2121 if exclude_keyvalues: 2122 where_clause += " AND ".join("%s != ?" % (k,) for k in exclude_keyvalues) 2123 arg_list += list(exclude_keyvalues.values()) 2124 2125 sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( 2126 ", ".join(retcols), 2127 table, 2128 where_clause, 2129 orderby, 2130 order_direction, 2131 ) 2132 txn.execute(sql, arg_list + [limit, start]) 2133 2134 return cls.cursor_to_dict(txn) 2135 2136 async def simple_search_list( 2137 self, 2138 table: str, 2139 term: Optional[str], 2140 col: str, 2141 retcols: Collection[str], 2142 desc="simple_search_list", 2143 ) -> Optional[List[Dict[str, Any]]]: 2144 """Executes a SELECT query on the named table, which may return zero or 2145 more rows, returning the result as a list of dicts. 2146 2147 Args: 2148 table: the table name 2149 term: term for searching the table matched to a column. 2150 col: column to query term should be matched to 2151 retcols: the names of the columns to return 2152 2153 Returns: 2154 A list of dictionaries or None. 2155 """ 2156 2157 return await self.runInteraction( 2158 desc, 2159 self.simple_search_list_txn, 2160 table, 2161 term, 2162 col, 2163 retcols, 2164 db_autocommit=True, 2165 ) 2166 2167 @classmethod 2168 def simple_search_list_txn( 2169 cls, 2170 txn: LoggingTransaction, 2171 table: str, 2172 term: Optional[str], 2173 col: str, 2174 retcols: Iterable[str], 2175 ) -> Optional[List[Dict[str, Any]]]: 2176 """Executes a SELECT query on the named table, which may return zero or 2177 more rows, returning the result as a list of dicts. 2178 2179 Args: 2180 txn: Transaction object 2181 table: the table name 2182 term: term for searching the table matched to a column. 2183 col: column to query term should be matched to 2184 retcols: the names of the columns to return 2185 2186 Returns: 2187 None if no term is given, otherwise a list of dictionaries. 2188 """ 2189 if term: 2190 sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col) 2191 termvalues = ["%%" + term + "%%"] 2192 txn.execute(sql, termvalues) 2193 else: 2194 return None 2195 2196 return cls.cursor_to_dict(txn) 2197 2198 2199def make_in_list_sql_clause( 2200 database_engine: BaseDatabaseEngine, column: str, iterable: Collection[Any] 2201) -> Tuple[str, list]: 2202 """Returns an SQL clause that checks the given column is in the iterable. 2203 2204 On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres 2205 it expands to `column = ANY(?)`. While both DBs support the `IN` form, 2206 using the `ANY` form on postgres means that it views queries with 2207 different length iterables as the same, helping the query stats. 2208 2209 Args: 2210 database_engine 2211 column: Name of the column 2212 iterable: The values to check the column against. 2213 2214 Returns: 2215 A tuple of SQL query and the args 2216 """ 2217 2218 if database_engine.supports_using_any_list: 2219 # This should hopefully be faster, but also makes postgres query 2220 # stats easier to understand. 2221 return "%s = ANY(?)" % (column,), [list(iterable)] 2222 else: 2223 return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable) 2224 2225 2226KV = TypeVar("KV") 2227 2228 2229def make_tuple_comparison_clause(keys: List[Tuple[str, KV]]) -> Tuple[str, List[KV]]: 2230 """Returns a tuple comparison SQL clause 2231 2232 Builds a SQL clause that looks like "(a, b) > (?, ?)" 2233 2234 Args: 2235 keys: A set of (column, value) pairs to be compared. 2236 2237 Returns: 2238 A tuple of SQL query and the args 2239 """ 2240 return ( 2241 "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)), 2242 [k[1] for k in keys], 2243 ) 2244