# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import heapq
import logging
import threading
from collections import OrderedDict
from contextlib import contextmanager
from types import TracebackType
from typing import (
    AsyncContextManager,
    ContextManager,
    Dict,
    Generator,
    Generic,
    Iterable,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

import attr
from sortedcontainers import SortedList, SortedSet

from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import PostgresSequenceGenerator

logger = logging.getLogger(__name__)


T = TypeVar("T")


class IdGenerator:
    """Hands out increasing integer IDs for a single table column.

    The starting point is read once from the database (the current maximum of
    `table.column`); after that IDs are allocated purely in memory under a lock.
    """

    def __init__(
        self,
        db_conn: LoggingDatabaseConnection,
        table: str,
        column: str,
    ):
        self._lock = threading.Lock()
        self._next_id = _load_current_id(db_conn, table, column)

    def get_next(self) -> int:
        """Return the next ID: one greater than the previously returned ID."""
        with self._lock:
            self._next_id += 1
            return self._next_id


def _load_current_id(
    db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1
) -> int:
    """Read the current extreme value of `table.column` to seed a generator.

    Args:
        db_conn: connection used to run the query.
        table: table to read from (interpolated into SQL; must be trusted).
        column: column to read from (interpolated into SQL; must be trusted).
        step: direction the IDs grow in. A positive step reads the MAX of the
            column, a negative step reads the MIN.

    Returns:
        The stored extreme value, clamped so that it is never "before" `step`
        (i.e. at least `step` for positive steps, at most `step` for negative
        ones). An empty table (or a stored value of 0) yields `step`.
    """
    cur = db_conn.cursor(txn_name="_load_current_id")
    try:
        # Use the same direction test as the clamping below so that any
        # positive step reads MAX and any negative step reads MIN.
        if step > 0:
            cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
        else:
            cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
        result = cur.fetchone()
        assert result is not None
        (val,) = result
    finally:
        # Always release the cursor, even if the query fails.
        cur.close()

    current_id = int(val) if val else step
    res = (max if step > 0 else min)(current_id, step)
    logger.info("Initialising stream generator for %s(%s): %i", table, column, res)
    return res


class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
    """Tracks the "current" stream ID of a stream that may have multiple writers.

    Stream IDs are monotonically increasing or decreasing integers representing write
    transactions. The "current" stream ID is the stream ID such that all transactions
    with equal or smaller stream IDs have completed. Since transactions may complete out
    of order, this is not the same as the stream ID of the last completed transaction.

    Completed transactions include both committed transactions and transactions that
    have been rolled back.
    """

    @abc.abstractmethod
    def advance(self, instance_name: str, new_id: int) -> None:
        """Advance the position of the named writer to the given ID, if greater
        than existing entry.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_current_token(self) -> int:
        """Returns the maximum stream id such that all stream ids less than or
        equal to it have been successfully persisted.

        Returns:
            The maximum stream id.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_current_token_for_writer(self, instance_name: str) -> int:
        """Returns the position of the given writer.

        For streams with single writers this is equivalent to `get_current_token`.
        """
        raise NotImplementedError()


class AbstractStreamIdGenerator(AbstractStreamIdTracker):
    """Generates stream IDs for a stream that may have multiple writers.

    Each stream ID represents a write transaction, whose completion is tracked
    so that the "current" stream ID of the stream can be determined.

    See `AbstractStreamIdTracker` for more details.
    """

    @abc.abstractmethod
    def get_next(self) -> AsyncContextManager[int]:
        """
        Usage:
            async with stream_id_gen.get_next() as stream_id:
                # ... persist event ...
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
        """
        Usage:
            async with stream_id_gen.get_next_mult(n) as stream_ids:
                # ... persist events ...
        """
        raise NotImplementedError()


class StreamIdGenerator(AbstractStreamIdGenerator):
    """Generates and tracks stream IDs for a stream with a single writer.

    This class must only be used when the current Synapse process is the sole
    writer for a stream.

    Args:
        db_conn(connection): A database connection to use to fetch the
            initial value of the generator from.
        table(str): A database table to read the initial value of the id
            generator from.
        column(str): The column of the database table to read the initial
            value of the id generator from.
        extra_tables(list): List of pairs of database tables and columns to
            use to source the initial value of the generator from. The value
            with the largest magnitude is used.
        step(int): which direction the stream ids grow in. +1 to grow
            upwards, -1 to grow downwards.

    Usage:
        async with stream_id_gen.get_next() as stream_id:
            # ... persist event ...
    """

    def __init__(
        self,
        db_conn: LoggingDatabaseConnection,
        table: str,
        column: str,
        extra_tables: Iterable[Tuple[str, str]] = (),
        step: int = 1,
    ) -> None:
        assert step != 0
        self._lock = threading.Lock()
        self._step: int = step
        self._current: int = _load_current_id(db_conn, table, column, step)
        for extra_table, extra_column in extra_tables:
            self._current = (max if step > 0 else min)(
                self._current,
                _load_current_id(db_conn, extra_table, extra_column, step),
            )

        # We use this as an ordered set, as we want to efficiently append items,
        # remove items and get the first item. Since we insert IDs in order, the
        # insertion ordering will ensure it's in the correct ordering.
        #
        # The key and values are the same, but we never look at the values.
        self._unfinished_ids: OrderedDict[int, int] = OrderedDict()

    def advance(self, instance_name: str, new_id: int) -> None:
        # `StreamIdGenerator` should only be used when there is a single writer,
        # so replication should never happen.
        raise Exception("Replication is not supported by StreamIdGenerator")

    def get_next(self) -> AsyncContextManager[int]:
        with self._lock:
            self._current += self._step
            next_id = self._current

            self._unfinished_ids[next_id] = next_id

        @contextmanager
        def manager() -> Generator[int, None, None]:
            try:
                yield next_id
            finally:
                # The ID is finished whether the caller succeeded or failed.
                with self._lock:
                    self._unfinished_ids.pop(next_id)

        return _AsyncCtxManagerWrapper(manager())

    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
        with self._lock:
            next_ids = range(
                self._current + self._step,
                self._current + self._step * (n + 1),
                self._step,
            )
            self._current += n * self._step

            for next_id in next_ids:
                self._unfinished_ids[next_id] = next_id

        @contextmanager
        def manager() -> Generator[Sequence[int], None, None]:
            try:
                yield next_ids
            finally:
                with self._lock:
                    for next_id in next_ids:
                        self._unfinished_ids.pop(next_id)

        return _AsyncCtxManagerWrapper(manager())

    def get_current_token(self) -> int:
        with self._lock:
            if self._unfinished_ids:
                # Everything strictly before the oldest unfinished ID is done.
                return next(iter(self._unfinished_ids)) - self._step

            return self._current

    def get_current_token_for_writer(self, instance_name: str) -> int:
        return self.get_current_token()


class MultiWriterIdGenerator(AbstractStreamIdGenerator):
    """Generates and tracks stream IDs for a stream with multiple writers.

    Uses a Postgres sequence to coordinate ID assignment, but positions of other
    writers will only get updated when `advance` is called (by replication).

    Note: Only works with Postgres.

    Args:
        db_conn
        db
        stream_name: A name for the stream, for use in the `stream_positions`
            table. (Does not need to be the same as the replication stream name)
        instance_name: The name of this instance.
        tables: List of tables associated with the stream. Tuple of table
            name, column name that stores the writer's instance name, and
            column name that stores the stream ID.
        sequence_name: The name of the postgres sequence used to generate new
            IDs.
        writers: A list of known writers to use to populate current positions
            on startup. Can be empty if nothing uses `get_current_token` or
            `get_positions` (e.g. caches stream).
        positive: Whether the IDs are positive (true) or negative (false).
            When using negative IDs we go backwards from -1 to -2, -3, etc.
    """

    def __init__(
        self,
        db_conn: LoggingDatabaseConnection,
        db: DatabasePool,
        stream_name: str,
        instance_name: str,
        tables: List[Tuple[str, str, str]],
        sequence_name: str,
        writers: List[str],
        positive: bool = True,
    ) -> None:
        self._db = db
        self._stream_name = stream_name
        self._instance_name = instance_name
        self._positive = positive
        self._writers = writers
        self._return_factor = 1 if positive else -1

        # We lock as some functions may be called from DB threads.
        self._lock = threading.Lock()

        # Note: If we are a negative stream then we still store all the IDs as
        # positive to make life easier for us, and simply negate the IDs when we
        # return them.
        self._current_positions: Dict[str, int] = {}

        # Set of local IDs that we're still processing. The current position
        # should be less than the minimum of this set (if not empty).
        self._unfinished_ids: SortedSet[int] = SortedSet()

        # We also need to track when we've requested some new stream IDs but
        # they haven't yet been added to the `_unfinished_ids` set. Every time
        # we request a new stream ID we add the current max stream ID to the
        # list, and remove it once we've added the newly allocated IDs to the
        # `_unfinished_ids` set. This means that we *may* be allocated stream
        # IDs above those in the list, and so we can't advance the local current
        # position beyond the minimum stream ID in this list.
        self._in_flight_fetches: SortedList[int] = SortedList()

        # Set of local IDs that we've processed that are larger than the current
        # position, due to there being smaller unpersisted IDs.
        self._finished_ids: Set[int] = set()

        # We track the max position where we know everything before has been
        # persisted. This is done by a) looking at the min across all instances
        # and b) noting that if we have seen a run of persisted positions
        # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
        #
        # Note: There is no guarantee that the IDs generated by the sequence
        # will be gapless; gaps can form when e.g. a transaction was rolled
        # back. This means that sometimes we won't be able to skip forward the
        # position even though everything has been persisted. However, since
        # gaps should be relatively rare it's still worth doing the book keeping
        # that allows us to skip forwards when there are gapless runs of
        # positions.
        #
        # We start at 1 here as a) the first generated stream ID will be 2, and
        # b) other parts of the code assume that stream IDs are strictly greater
        # than 0. (`_load_current_ids` below refines this from the database.)
        self._persisted_upto_position = 1
        self._known_persisted_positions: List[int] = []

        # The maximum stream ID that we have seen allocated across any writer.
        self._max_seen_allocated_stream_id = 1

        self._sequence_gen = PostgresSequenceGenerator(sequence_name)

        # We check that the table and sequence haven't diverged.
        for table, _, id_column in tables:
            self._sequence_gen.check_consistency(
                db_conn,
                table=table,
                id_column=id_column,
                stream_name=stream_name,
                positive=positive,
            )

        # This goes and fills out the above state from the database.
        self._load_current_ids(db_conn, tables)

        self._max_seen_allocated_stream_id = max(
            self._current_positions.values(), default=1
        )

    def _load_current_ids(
        self,
        db_conn: LoggingDatabaseConnection,
        tables: List[Tuple[str, str, str]],
    ) -> None:
        """Populate the in-memory position state from the database.

        Opens a cursor on `db_conn`, delegates the actual work, and always
        closes the cursor afterwards, even if a query fails.
        """
        cur = db_conn.cursor(txn_name="_load_current_ids")
        try:
            self._load_current_ids_with_cursor(cur, tables)
        finally:
            cur.close()

    def _load_current_ids_with_cursor(
        self,
        cur: LoggingTransaction,
        tables: List[Tuple[str, str, str]],
    ) -> None:
        """Fill `_current_positions` / `_persisted_upto_position` using `cur`."""

        # Load the current positions of all writers for the stream.
        if self._writers:
            # We delete any stale entries in the positions table. This is
            # important if we add back a writer after a long time; we want to
            # consider that a "new" writer, rather than using the old stale
            # entry here.
            sql = """
                DELETE FROM stream_positions
                WHERE
                    stream_name = ?
                    AND instance_name != ALL(?)
            """
            cur.execute(sql, (self._stream_name, self._writers))

            sql = """
                SELECT instance_name, stream_id FROM stream_positions
                WHERE stream_name = ?
            """
            cur.execute(sql, (self._stream_name,))

            self._current_positions = {
                instance: stream_id * self._return_factor
                for instance, stream_id in cur
                if instance in self._writers
            }

        # We set the `_persisted_upto_position` to be the minimum of all current
        # positions. If empty we use the max stream ID from the DB table.
        min_stream_id = min(self._current_positions.values(), default=None)

        if min_stream_id is None:
            # We add a GREATEST here to ensure that the result is always
            # positive. (This can be a problem for e.g. backfill streams where
            # the server has never backfilled).
            max_stream_id = 1
            for table, _, id_column in tables:
                sql = """
                    SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
                    FROM %(table)s
                """ % {
                    "id": id_column,
                    "table": table,
                    "agg": "MAX" if self._positive else "-MIN",
                }
                cur.execute(sql)
                result = cur.fetchone()
                assert result is not None
                (stream_id,) = result

                max_stream_id = max(max_stream_id, stream_id)

            self._persisted_upto_position = max_stream_id
        else:
            # If we have a min_stream_id then we pull out everything greater
            # than it from the DB so that we can prefill
            # `_known_persisted_positions` and get a more accurate
            # `_persisted_upto_position`.
            #
            # We also check if any of the later rows are from this instance, in
            # which case we use that for this instance's current position. This
            # is to handle the case where we didn't finish persisting to the
            # stream positions table before restart (or the stream position
            # table otherwise got out of date).

            self._persisted_upto_position = min_stream_id

            rows: List[Tuple[str, int]] = []
            for table, instance_column, id_column in tables:
                sql = """
                    SELECT %(instance)s, %(id)s FROM %(table)s
                    WHERE ? %(cmp)s %(id)s
                """ % {
                    "id": id_column,
                    "table": table,
                    "instance": instance_column,
                    "cmp": "<=" if self._positive else ">=",
                }
                cur.execute(sql, (min_stream_id * self._return_factor,))

                # Cast safety: this corresponds to the types returned by the query above.
                rows.extend(cast(Iterable[Tuple[str, int]], cur))

            # Sort so that we handle rows in order for each instance.
            rows.sort()

            with self._lock:
                for (
                    instance,
                    stream_id,
                ) in rows:
                    stream_id = self._return_factor * stream_id
                    self._add_persisted_position(stream_id)

                    if instance == self._instance_name:
                        self._current_positions[instance] = stream_id

    def _load_next_id_txn(self, txn: Cursor) -> int:
        stream_ids = self._load_next_mult_id_txn(txn, 1)
        return stream_ids[0]

    def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]:
        # We need to track that we've requested some more stream IDs, and what
        # the current max allocated stream ID is. This is to prevent a race
        # where we've been allocated stream IDs but they have not yet been added
        # to the `_unfinished_ids` set, allowing the current position to advance
        # past them.
        with self._lock:
            current_max = self._max_seen_allocated_stream_id
            self._in_flight_fetches.add(current_max)

        try:
            stream_ids = self._sequence_gen.get_next_mult_txn(txn, n)

            with self._lock:
                self._unfinished_ids.update(stream_ids)
                self._max_seen_allocated_stream_id = max(
                    self._max_seen_allocated_stream_id, self._unfinished_ids[-1]
                )
        finally:
            with self._lock:
                self._in_flight_fetches.remove(current_max)

        return stream_ids

    def get_next(self) -> AsyncContextManager[int]:
        # If we have a list of instances that are allowed to write to this
        # stream, make sure we're in it.
        if self._writers and self._instance_name not in self._writers:
            raise Exception("Tried to allocate stream ID on non-writer")

        # Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
        # controls the return type. If `None` or omitted, the context manager yields
        # a single integer stream_id; otherwise it yields a list of stream_ids.
        return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))

    def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
        # If we have a list of instances that are allowed to write to this
        # stream, make sure we're in it.
        if self._writers and self._instance_name not in self._writers:
            raise Exception("Tried to allocate stream ID on non-writer")

        # Cast safety: see get_next.
        return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))

    def get_next_txn(self, txn: LoggingTransaction) -> int:
        """
        Usage:

            stream_id = stream_id_gen.get_next_txn(txn)
            # ... persist event ...
        """

        # If we have a list of instances that are allowed to write to this
        # stream, make sure we're in it.
        if self._writers and self._instance_name not in self._writers:
            raise Exception("Tried to allocate stream ID on non-writer")

        next_id = self._load_next_id_txn(txn)

        txn.call_after(self._mark_id_as_finished, next_id)
        txn.call_on_exception(self._mark_id_as_finished, next_id)

        # Update the `stream_positions` table with newly updated stream
        # ID (unless self._writers is not set in which case we don't
        # bother, as nothing will read it).
        #
        # We only do this on the success path so that the persisted current
        # position points to a persisted row with the correct instance name.
        if self._writers:
            txn.call_after(
                run_as_background_process,
                "MultiWriterIdGenerator._update_table",
                self._db.runInteraction,
                "MultiWriterIdGenerator._update_table",
                self._update_stream_positions_table_txn,
            )

        return self._return_factor * next_id

    def _mark_id_as_finished(self, next_id: int) -> None:
        """The ID has finished being processed so we should advance the
        current position if possible.
        """

        with self._lock:
            self._unfinished_ids.discard(next_id)
            self._finished_ids.add(next_id)

            new_cur: Optional[int] = None

            if self._unfinished_ids or self._in_flight_fetches:
                # If there are unfinished IDs then the new position will be the
                # largest finished ID strictly less than the minimum unfinished
                # ID.

                # The minimum unfinished ID needs to take account of both
                # `_unfinished_ids` and `_in_flight_fetches`.
                if self._unfinished_ids and self._in_flight_fetches:
                    # `_in_flight_fetches` stores the maximum safe stream ID, so
                    # we add one to make it equivalent to the minimum unsafe ID.
                    min_unfinished = min(
                        self._unfinished_ids[0], self._in_flight_fetches[0] + 1
                    )
                elif self._in_flight_fetches:
                    min_unfinished = self._in_flight_fetches[0] + 1
                else:
                    min_unfinished = self._unfinished_ids[0]

                finished = set()
                for s in self._finished_ids:
                    if s < min_unfinished:
                        if new_cur is None or new_cur < s:
                            new_cur = s
                    else:
                        finished.add(s)

                # We clear these out since they're now all less than the new
                # position.
                self._finished_ids = finished
            else:
                # There are no unfinished IDs so the new position is simply the
                # largest finished one.
                new_cur = max(self._finished_ids)

                # We clear these out since they're now all less than the new
                # position.
                self._finished_ids.clear()

            if new_cur:
                curr = self._current_positions.get(self._instance_name, 0)
                self._current_positions[self._instance_name] = max(curr, new_cur)

            self._add_persisted_position(next_id)

    def get_current_token(self) -> int:
        return self.get_persisted_upto_position()

    def get_current_token_for_writer(self, instance_name: str) -> int:
        # If we don't have an entry for the given instance name, we assume it's a
        # new writer.
        #
        # For new writers we assume their initial position to be the current
        # persisted up to position. This stops Synapse from doing a full table
        # scan when a new writer announces itself over replication.
        with self._lock:
            return self._return_factor * self._current_positions.get(
                instance_name, self._persisted_upto_position
            )

    def get_positions(self) -> Dict[str, int]:
        """Get a copy of the current position map.

        Note that this won't necessarily include all configured writers if some
        writers haven't written anything yet.
        """

        with self._lock:
            return {
                name: self._return_factor * i
                for name, i in self._current_positions.items()
            }

    def advance(self, instance_name: str, new_id: int) -> None:
        new_id *= self._return_factor

        with self._lock:
            self._current_positions[instance_name] = max(
                new_id, self._current_positions.get(instance_name, 0)
            )

            self._max_seen_allocated_stream_id = max(
                self._max_seen_allocated_stream_id, new_id
            )

            self._add_persisted_position(new_id)

    def get_persisted_upto_position(self) -> int:
        """Get the max position where all previous positions have been
        persisted.

        Note: In the worst case scenario this will be equal to the minimum
        position across writers. This means that the returned position here can
        lag if one writer doesn't write very often.
        """

        with self._lock:
            return self._return_factor * self._persisted_upto_position

    def _add_persisted_position(self, new_id: int) -> None:
        """Record that we have persisted a position.

        This is used to advance `_persisted_upto_position` (and, when this
        instance has no writes in progress, its own entry in
        `_current_positions`).
        """

        # We require that the lock is locked by caller
        assert self._lock.locked()

        heapq.heappush(self._known_persisted_positions, new_id)

        # If we're a writer and we don't have any active writes we update our
        # current position to the latest position seen. This allows the instance
        # to report a recent position when asked, rather than a potentially old
        # one (if this instance hasn't written anything for a while).
        our_current_position = self._current_positions.get(self._instance_name)
        if (
            our_current_position
            and not self._unfinished_ids
            and not self._in_flight_fetches
        ):
            self._current_positions[self._instance_name] = max(
                our_current_position, new_id
            )

        # We move the current min position up if the minimum current positions
        # of all instances is higher (since by definition all positions less
        # than that have been persisted).
        min_curr = min(self._current_positions.values(), default=0)
        self._persisted_upto_position = max(min_curr, self._persisted_upto_position)

        # We now iterate through the seen positions, discarding those that are
        # less than the current min positions, and incrementing the min position
        # if its exactly one greater.
        #
        # This is also where we discard items from `_known_persisted_positions`
        # (to ensure the list doesn't infinitely grow).
        while self._known_persisted_positions:
            if self._known_persisted_positions[0] <= self._persisted_upto_position:
                heapq.heappop(self._known_persisted_positions)
            elif (
                self._known_persisted_positions[0] == self._persisted_upto_position + 1
            ):
                heapq.heappop(self._known_persisted_positions)
                self._persisted_upto_position += 1
            else:
                # There was a gap in seen positions, so there is nothing more to
                # do.
                break

    def _update_stream_positions_table_txn(self, txn: Cursor) -> None:
        """Update the `stream_positions` table with newly persisted position."""

        if not self._writers:
            return

        # We upsert the value, ensuring on conflict that we always increase the
        # value (or decrease if stream goes backwards).
        sql = """
            INSERT INTO stream_positions (stream_name, instance_name, stream_id)
            VALUES (?, ?, ?)
            ON CONFLICT (stream_name, instance_name)
            DO UPDATE SET
                stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
        """ % {
            "agg": "GREATEST" if self._positive else "LEAST",
        }

        # Bind the position as a plain integer: each `?` placeholder takes one
        # scalar value (previously a 1-tuple was passed here, which only worked
        # because the driver adapted it to a parenthesised expression).
        pos = self.get_current_token_for_writer(self._instance_name)
        txn.execute(sql, (self._stream_name, self._instance_name, pos))


@attr.s(frozen=True, auto_attribs=True)
class _AsyncCtxManagerWrapper(Generic[T]):
    """Helper class to convert a plain context manager to an async one.

    This is mainly useful if you have a plain context manager but the interface
    requires an async one.
    """

    inner: ContextManager[T]

    async def __aenter__(self) -> T:
        return self.inner.__enter__()

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> Optional[bool]:
        return self.inner.__exit__(exc_type, exc, tb)


@attr.s(slots=True)
class _MultiWriterCtxManager:
    """Async context manager returned by MultiWriterIdGenerator"""

    id_gen = attr.ib(type=MultiWriterIdGenerator)
    multiple_ids = attr.ib(type=Optional[int], default=None)
    stream_ids = attr.ib(type=List[int], factory=list)

    async def __aenter__(self) -> Union[int, List[int]]:
        # It's safe to run this in autocommit mode as fetching values from a
        # sequence ignores transaction semantics anyway.
        self.stream_ids = await self.id_gen._db.runInteraction(
            "_load_next_mult_id",
            self.id_gen._load_next_mult_id_txn,
            self.multiple_ids or 1,
            db_autocommit=True,
        )

        if self.multiple_ids is None:
            return self.stream_ids[0] * self.id_gen._return_factor
        else:
            return [i * self.id_gen._return_factor for i in self.stream_ids]

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> bool:
        for i in self.stream_ids:
            self.id_gen._mark_id_as_finished(i)

        if exc_type is not None:
            return False

        # Update the `stream_positions` table with newly updated stream
        # ID (unless self._writers is not set in which case we don't
        # bother, as nothing will read it).
        #
        # We only do this on the success path so that the persisted current
        # position points to a persisted row with the correct instance name.
        #
        # We do this in autocommit mode as a) the upsert works correctly outside
        # transactions and b) reduces the amount of time the rows are locked
        # for. If we don't do this then we'll often hit serialization errors due
        # to the fact we default to REPEATABLE READ isolation levels.
        if self.id_gen._writers:
            await self.id_gen._db.runInteraction(
                "MultiWriterIdGenerator._update_table",
                self.id_gen._update_stream_positions_table_txn,
                db_autocommit=True,
            )

        return False