# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import heapq
import logging
import threading
from collections import OrderedDict
from contextlib import contextmanager
from types import TracebackType
from typing import (
    AsyncContextManager,
    ContextManager,
    Dict,
    Generator,
    Generic,
    Iterable,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

import attr
from sortedcontainers import SortedList, SortedSet

from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import PostgresSequenceGenerator

logger = logging.getLogger(__name__)


T = TypeVar("T")


class IdGenerator:
    def __init__(
        self,
        db_conn: LoggingDatabaseConnection,
        table: str,
        column: str,
    ):
        self._lock = threading.Lock()
        self._next_id = _load_current_id(db_conn, table, column)

    def get_next(self) -> int:
        with self._lock:
            self._next_id += 1
            return self._next_id

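# Usage sketch for `IdGenerator` (illustrative, not executed at import time):
# seeding from a hypothetical `access_tokens(id)` table whose current MAX(id)
# is 10.
#
#     gen = IdGenerator(db_conn, "access_tokens", "id")
#     gen.get_next()  # -> 11
#     gen.get_next()  # -> 12; calls are thread-safe under the internal lock
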
def _load_current_id(
    db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1
) -> int:
    cur = db_conn.cursor(txn_name="_load_current_id")
    if step == 1:
        cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
    else:
        cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
    result = cur.fetchone()
    assert result is not None
    (val,) = result
    cur.close()
    current_id = int(val) if val else step
    res = (max if step > 0 else min)(current_id, step)
    logger.info("Initialising stream generator for %s(%s): %i", table, column, res)
    return res

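# Worked example of the clamping above (assumed values): for a forward stream
# (step=1) over an empty table, `val` is NULL, so the result is max(1, 1) = 1.
# For a backwards stream (step=-1) whose MIN(column) is -5, the result is
# min(-5, -1) = -5; on an empty table it is min(-1, -1) = -1.
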
class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
    """Tracks the "current" stream ID of a stream that may have multiple writers.

    Stream IDs are monotonically increasing or decreasing integers representing write
    transactions. The "current" stream ID is the stream ID such that all transactions
    with equal or smaller stream IDs have completed. Since transactions may complete out
    of order, this is not the same as the stream ID of the last completed transaction.

    Completed transactions include both committed transactions and transactions that
    have been rolled back.
    """

    @abc.abstractmethod
    def advance(self, instance_name: str, new_id: int) -> None:
        """Advance the position of the named writer to the given ID, if greater
        than existing entry.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_current_token(self) -> int:
        """Returns the maximum stream id such that all stream ids less than or
        equal to it have been successfully persisted.

        Returns:
            The maximum stream id.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_current_token_for_writer(self, instance_name: str) -> int:
        """Returns the position of the given writer.

        For streams with single writers this is equivalent to `get_current_token`.
        """
        raise NotImplementedError()

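# Worked example of the "current" stream ID (assumed values): if transactions
# 1, 2 and 3 are in flight and 1 and 3 complete, the current stream ID is 1,
# not 3, because 2 has not yet completed. Once 2 completes, the current ID
# jumps straight to 3.
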
class AbstractStreamIdGenerator(AbstractStreamIdTracker):
    """Generates stream IDs for a stream that may have multiple writers.

    Each stream ID represents a write transaction, whose completion is tracked
    so that the "current" stream ID of the stream can be determined.

    See `AbstractStreamIdTracker` for more details.
    """

    @abc.abstractmethod
    def get_next(self) -> AsyncContextManager[int]:
        """
        Usage:
            async with stream_id_gen.get_next() as stream_id:
                # ... persist event ...
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
        """
        Usage:
            async with stream_id_gen.get_next_mult(n) as stream_ids:
                # ... persist events ...
        """
        raise NotImplementedError()


class StreamIdGenerator(AbstractStreamIdGenerator):
    """Generates and tracks stream IDs for a stream with a single writer.

    This class must only be used when the current Synapse process is the sole
    writer for a stream.

    Args:
        db_conn(connection): A database connection to use to fetch the
            initial value of the generator from.
        table(str): A database table to read the initial value of the id
            generator from.
        column(str): The column of the database table to read the initial
            value of the id generator from.
        extra_tables(list): List of pairs of database tables and columns to
            use to source the initial value of the generator from. The value
            with the largest magnitude is used.
        step(int): Which direction the stream ids grow in. +1 to grow
            upwards, -1 to grow downwards.

    Usage:
        async with stream_id_gen.get_next() as stream_id:
            # ... persist event ...
    """

    def __init__(
        self,
        db_conn: LoggingDatabaseConnection,
        table: str,
        column: str,
        extra_tables: Iterable[Tuple[str, str]] = (),
        step: int = 1,
    ) -> None:
        assert step != 0
        self._lock = threading.Lock()
        self._step: int = step
        self._current: int = _load_current_id(db_conn, table, column, step)
        for table, column in extra_tables:
            self._current = (max if step > 0 else min)(
                self._current, _load_current_id(db_conn, table, column, step)
            )

        # We use this as an ordered set, as we want to efficiently append items,
        # remove items and get the first item. Since we insert IDs in order, the
        # insertion ordering will ensure it's in the correct order.
        #
        # The key and values are the same, but we never look at the values.
        self._unfinished_ids: OrderedDict[int, int] = OrderedDict()

    def advance(self, instance_name: str, new_id: int) -> None:
        # `StreamIdGenerator` should only be used when there is a single writer,
        # so replication should never happen.
        raise Exception("Replication is not supported by StreamIdGenerator")

    def get_next(self) -> AsyncContextManager[int]:
        with self._lock:
            self._current += self._step
            next_id = self._current

            self._unfinished_ids[next_id] = next_id

        @contextmanager
        def manager() -> Generator[int, None, None]:
            try:
                yield next_id
            finally:
                with self._lock:
                    self._unfinished_ids.pop(next_id)

        return _AsyncCtxManagerWrapper(manager())

    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
        with self._lock:
            next_ids = range(
                self._current + self._step,
                self._current + self._step * (n + 1),
                self._step,
            )
            self._current += n * self._step

            for next_id in next_ids:
                self._unfinished_ids[next_id] = next_id

        @contextmanager
        def manager() -> Generator[Sequence[int], None, None]:
            try:
                yield next_ids
            finally:
                with self._lock:
                    for next_id in next_ids:
                        self._unfinished_ids.pop(next_id)

        return _AsyncCtxManagerWrapper(manager())

    def get_current_token(self) -> int:
        with self._lock:
            if self._unfinished_ids:
                return next(iter(self._unfinished_ids)) - self._step

            return self._current

    def get_current_token_for_writer(self, instance_name: str) -> int:
        return self.get_current_token()


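# Usage sketch for `StreamIdGenerator` (illustrative; `persist_event` is a
# hypothetical coroutine). If two overlapping calls hand out IDs 6 and 7 and 7
# finishes first, `get_current_token()` still returns 5 until 6 also finishes:
#
#     async with stream_id_gen.get_next() as stream_id:
#         await persist_event(stream_id)
#     # Once both 6 and 7 have exited their context managers, the current
#     # token advances to 7.
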
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
    """Generates and tracks stream IDs for a stream with multiple writers.

    Uses a Postgres sequence to coordinate ID assignment, but positions of other
    writers will only get updated when `advance` is called (by replication).

    Note: Only works with Postgres.

    Args:
        db_conn
        db
        stream_name: A name for the stream, for use in the `stream_positions`
            table. (Does not need to be the same as the replication stream name)
        instance_name: The name of this instance.
        tables: List of tables associated with the stream. Tuple of table
            name, column name that stores the writer's instance name, and
            column name that stores the stream ID.
        sequence_name: The name of the postgres sequence used to generate new
            IDs.
        writers: A list of known writers to use to populate current positions
            on startup. Can be empty if nothing uses `get_current_token` or
            `get_positions` (e.g. caches stream).
        positive: Whether the IDs are positive (true) or negative (false).
            When using negative IDs we go backwards from -1 to -2, -3, etc.
    """

    def __init__(
        self,
        db_conn: LoggingDatabaseConnection,
        db: DatabasePool,
        stream_name: str,
        instance_name: str,
        tables: List[Tuple[str, str, str]],
        sequence_name: str,
        writers: List[str],
        positive: bool = True,
    ) -> None:
        self._db = db
        self._stream_name = stream_name
        self._instance_name = instance_name
        self._positive = positive
        self._writers = writers
        self._return_factor = 1 if positive else -1

        # We lock as some functions may be called from DB threads.
        self._lock = threading.Lock()

        # Note: If we are a negative stream then we still store all the IDs as
        # positive to make life easier for us, and simply negate the IDs when we
        # return them.
        self._current_positions: Dict[str, int] = {}

        # Set of local IDs that we're still processing. The current position
        # should be less than the minimum of this set (if not empty).
        self._unfinished_ids: SortedSet[int] = SortedSet()

        # We also need to track when we've requested some new stream IDs but
        # they haven't yet been added to the `_unfinished_ids` set. Every time
        # we request a new stream ID we add the current max stream ID to the
        # list, and remove it once we've added the newly allocated IDs to the
        # `_unfinished_ids` set. This means that we *may* be allocated stream
        # IDs above those in the list, and so we can't advance the local current
        # position beyond the minimum stream ID in this list.
        self._in_flight_fetches: SortedList[int] = SortedList()

        # Set of local IDs that we've processed that are larger than the current
        # position, due to there being smaller unpersisted IDs.
        self._finished_ids: Set[int] = set()

        # We track the max position where we know everything before has been
        # persisted. This is done by a) looking at the min across all instances
        # and b) noting that if we have seen a run of persisted positions
        # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
        #
        # Note: There is no guarantee that the IDs generated by the sequence
        # will be gapless; gaps can form when e.g. a transaction was rolled
        # back. This means that sometimes we won't be able to skip forward the
        # position even though everything has been persisted. However, since
        # gaps should be relatively rare it's still worth doing the bookkeeping
        # that allows us to skip forwards when there are gapless runs of
        # positions.
        #
        # We start at 1 here as a) the first generated stream ID will be 2, and
        # b) other parts of the code assume that stream IDs are strictly greater
        # than 0.
        self._persisted_upto_position = (
            min(self._current_positions.values()) if self._current_positions else 1
        )
        self._known_persisted_positions: List[int] = []

        # The maximum stream ID that we have seen allocated across any writer.
        self._max_seen_allocated_stream_id = 1

        self._sequence_gen = PostgresSequenceGenerator(sequence_name)

        # We check that the table and sequence haven't diverged.
        for table, _, id_column in tables:
            self._sequence_gen.check_consistency(
                db_conn,
                table=table,
                id_column=id_column,
                stream_name=stream_name,
                positive=positive,
            )

        # This goes and fills out the above state from the database.
        self._load_current_ids(db_conn, tables)

        self._max_seen_allocated_stream_id = max(
            self._current_positions.values(), default=1
        )

    def _load_current_ids(
        self,
        db_conn: LoggingDatabaseConnection,
        tables: List[Tuple[str, str, str]],
    ) -> None:
        cur = db_conn.cursor(txn_name="_load_current_ids")

        # Load the current positions of all writers for the stream.
        if self._writers:
            # We delete any stale entries in the positions table. This is
            # important if we add back a writer after a long time; we want to
            # consider that a "new" writer, rather than using the old stale
            # entry here.
            sql = """
                DELETE FROM stream_positions
                WHERE
                    stream_name = ?
                    AND instance_name != ALL(?)
            """
            cur.execute(sql, (self._stream_name, self._writers))

            sql = """
                SELECT instance_name, stream_id FROM stream_positions
                WHERE stream_name = ?
            """
            cur.execute(sql, (self._stream_name,))

            self._current_positions = {
                instance: stream_id * self._return_factor
                for instance, stream_id in cur
                if instance in self._writers
            }

        # We set the `_persisted_upto_position` to be the minimum of all current
        # positions. If empty we use the max stream ID from the DB table.
        min_stream_id = min(self._current_positions.values(), default=None)

        if min_stream_id is None:
            # We add a GREATEST here to ensure that the result is always
            # positive. (This can be a problem for e.g. backfill streams where
            # the server has never backfilled).
            max_stream_id = 1
            for table, _, id_column in tables:
                sql = """
                    SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
                    FROM %(table)s
                """ % {
                    "id": id_column,
                    "table": table,
                    "agg": "MAX" if self._positive else "-MIN",
                }
                cur.execute(sql)
                result = cur.fetchone()
                assert result is not None
                (stream_id,) = result

                max_stream_id = max(max_stream_id, stream_id)

            self._persisted_upto_position = max_stream_id
        else:
            # If we have a min_stream_id then we pull out everything greater
            # than it from the DB so that we can prefill
            # `_known_persisted_positions` and get a more accurate
            # `_persisted_upto_position`.
            #
            # We also check if any of the later rows are from this instance, in
            # which case we use that for this instance's current position. This
            # is to handle the case where we didn't finish persisting to the
            # stream positions table before restart (or the stream position
            # table otherwise got out of date).

            self._persisted_upto_position = min_stream_id

            rows: List[Tuple[str, int]] = []
            for table, instance_column, id_column in tables:
                sql = """
                    SELECT %(instance)s, %(id)s FROM %(table)s
                    WHERE ? %(cmp)s %(id)s
                """ % {
                    "id": id_column,
                    "table": table,
                    "instance": instance_column,
                    "cmp": "<=" if self._positive else ">=",
                }
                cur.execute(sql, (min_stream_id * self._return_factor,))

                # Cast safety: this corresponds to the types returned by the query above.
                rows.extend(cast(Iterable[Tuple[str, int]], cur))

            # Sort so that we handle rows in order for each instance.
            rows.sort()

            with self._lock:
                for instance, stream_id in rows:
                    stream_id = self._return_factor * stream_id
                    self._add_persisted_position(stream_id)

                    if instance == self._instance_name:
                        self._current_positions[instance] = stream_id

        cur.close()

    def _load_next_id_txn(self, txn: Cursor) -> int:
        stream_ids = self._load_next_mult_id_txn(txn, 1)
        return stream_ids[0]

    def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]:
        # We need to track that we've requested some more stream IDs, and what
        # the current max allocated stream ID is. This is to prevent a race
        # where we've been allocated stream IDs but they have not yet been added
        # to the `_unfinished_ids` set, allowing the current position to advance
        # past them.
        with self._lock:
            current_max = self._max_seen_allocated_stream_id
            self._in_flight_fetches.add(current_max)

        try:
            stream_ids = self._sequence_gen.get_next_mult_txn(txn, n)

            with self._lock:
                self._unfinished_ids.update(stream_ids)
                self._max_seen_allocated_stream_id = max(
                    self._max_seen_allocated_stream_id, self._unfinished_ids[-1]
                )
        finally:
            with self._lock:
                self._in_flight_fetches.remove(current_max)

        return stream_ids

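    # Worked example of the in-flight bookkeeping above (assumed values): if
    # the max seen stream ID is 10 when a fetch starts, 10 is recorded in
    # `_in_flight_fetches`. The sequence may then hand out 11 and 12; until
    # they land in `_unfinished_ids`, the current position must not advance
    # past 10, which is what `_mark_id_as_finished` enforces later on.
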
    def get_next(self) -> AsyncContextManager[int]:
        # If we have a list of instances that are allowed to write to this
        # stream, make sure we're in it.
        if self._writers and self._instance_name not in self._writers:
            raise Exception("Tried to allocate stream ID on non-writer")

        # Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
        # controls the return type. If `None` or omitted, the context manager yields
        # a single integer stream_id; otherwise it yields a list of stream_ids.
        return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))

    def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
        # If we have a list of instances that are allowed to write to this
        # stream, make sure we're in it.
        if self._writers and self._instance_name not in self._writers:
            raise Exception("Tried to allocate stream ID on non-writer")

        # Cast safety: see get_next.
        return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))

    def get_next_txn(self, txn: LoggingTransaction) -> int:
        """
        Usage:

            stream_id = stream_id_gen.get_next_txn(txn)
            # ... persist event ...
        """

        # If we have a list of instances that are allowed to write to this
        # stream, make sure we're in it.
        if self._writers and self._instance_name not in self._writers:
            raise Exception("Tried to allocate stream ID on non-writer")

        next_id = self._load_next_id_txn(txn)

        txn.call_after(self._mark_id_as_finished, next_id)
        txn.call_on_exception(self._mark_id_as_finished, next_id)

        # Update the `stream_positions` table with newly updated stream
        # ID (unless self._writers is not set in which case we don't
        # bother, as nothing will read it).
        #
        # We only do this on the success path so that the persisted current
        # position points to a persisted row with the correct instance name.
        if self._writers:
            txn.call_after(
                run_as_background_process,
                "MultiWriterIdGenerator._update_table",
                self._db.runInteraction,
                "MultiWriterIdGenerator._update_table",
                self._update_stream_positions_table_txn,
            )

        return self._return_factor * next_id

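    # Usage sketch for `get_next_txn` (illustrative; `_insert_row_txn` is a
    # hypothetical transaction function run via `DatabasePool.runInteraction`):
    #
    #     def _insert_row_txn(txn: LoggingTransaction) -> None:
    #         stream_id = self._id_gen.get_next_txn(txn)
    #         txn.execute("INSERT INTO some_table VALUES (?)", (stream_id,))
    #
    # The ID is marked as finished via the transaction's after/exception
    # callbacks, so no context manager is needed on this path.
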
    def _mark_id_as_finished(self, next_id: int) -> None:
        """The ID has finished being processed so we should advance the
        current position if possible.
        """

        with self._lock:
            self._unfinished_ids.discard(next_id)
            self._finished_ids.add(next_id)

            new_cur: Optional[int] = None

            if self._unfinished_ids or self._in_flight_fetches:
                # If there are unfinished IDs then the new position will be the
                # largest finished ID strictly less than the minimum unfinished
                # ID.

                # The minimum unfinished ID needs to take account of both
                # `_unfinished_ids` and `_in_flight_fetches`.
                if self._unfinished_ids and self._in_flight_fetches:
                    # `_in_flight_fetches` stores the maximum safe stream ID, so
                    # we add one to make it equivalent to the minimum unsafe ID.
                    min_unfinished = min(
                        self._unfinished_ids[0], self._in_flight_fetches[0] + 1
                    )
                elif self._in_flight_fetches:
                    min_unfinished = self._in_flight_fetches[0] + 1
                else:
                    min_unfinished = self._unfinished_ids[0]

                finished = set()
                for s in self._finished_ids:
                    if s < min_unfinished:
                        if new_cur is None or new_cur < s:
                            new_cur = s
                    else:
                        finished.add(s)

                # We clear these out since they're now all less than the new
                # position.
                self._finished_ids = finished
            else:
                # There are no unfinished IDs so the new position is simply the
                # largest finished one.
                new_cur = max(self._finished_ids)

                # We clear these out since they're now all less than the new
                # position.
                self._finished_ids.clear()

            if new_cur:
                curr = self._current_positions.get(self._instance_name, 0)
                self._current_positions[self._instance_name] = max(curr, new_cur)

            self._add_persisted_position(next_id)

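    # Worked example for `_mark_id_as_finished` (assumed values): with
    # `_unfinished_ids = {6, 8}` and `_finished_ids = {5, 7}`, finishing 6
    # leaves 8 as the minimum unfinished ID, so the new position is 7 (the
    # largest finished ID strictly below 8) and `_finished_ids` empties out.
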
    def get_current_token(self) -> int:
        return self.get_persisted_upto_position()

    def get_current_token_for_writer(self, instance_name: str) -> int:
        # If we don't have an entry for the given instance name, we assume it's a
        # new writer.
        #
        # For new writers we assume their initial position to be the current
        # persisted up to position. This stops Synapse from doing a full table
        # scan when a new writer announces itself over replication.
        with self._lock:
            return self._return_factor * self._current_positions.get(
                instance_name, self._persisted_upto_position
            )

    def get_positions(self) -> Dict[str, int]:
        """Get a copy of the current position map.

        Note that this won't necessarily include all configured writers if some
        writers haven't written anything yet.
        """

        with self._lock:
            return {
                name: self._return_factor * i
                for name, i in self._current_positions.items()
            }

    def advance(self, instance_name: str, new_id: int) -> None:
        new_id *= self._return_factor

        with self._lock:
            self._current_positions[instance_name] = max(
                new_id, self._current_positions.get(instance_name, 0)
            )

            self._max_seen_allocated_stream_id = max(
                self._max_seen_allocated_stream_id, new_id
            )

            self._add_persisted_position(new_id)

    def get_persisted_upto_position(self) -> int:
        """Get the max position where all previous positions have been
        persisted.

        Note: In the worst case scenario this will be equal to the minimum
        position across writers. This means that the returned position here can
        lag if one writer doesn't write very often.
        """

        with self._lock:
            return self._return_factor * self._persisted_upto_position

    def _add_persisted_position(self, new_id: int) -> None:
        """Record that we have persisted a position.

        This is used to keep the `_current_positions` up to date.
        """

        # We require that the lock is locked by caller
        assert self._lock.locked()

        heapq.heappush(self._known_persisted_positions, new_id)

        # If we're a writer and we don't have any active writes we update our
        # current position to the latest position seen. This allows the instance
        # to report a recent position when asked, rather than a potentially old
        # one (if this instance hasn't written anything for a while).
        our_current_position = self._current_positions.get(self._instance_name)
        if (
            our_current_position
            and not self._unfinished_ids
            and not self._in_flight_fetches
        ):
            self._current_positions[self._instance_name] = max(
                our_current_position, new_id
            )

        # We move the current min position up if the minimum current positions
        # of all instances is higher (since by definition all positions less
        # than that have been persisted).
        min_curr = min(self._current_positions.values(), default=0)
        self._persisted_upto_position = max(min_curr, self._persisted_upto_position)

        # We now iterate through the seen positions, discarding those that are
        # less than the current min positions, and incrementing the min position
        # if it's exactly one greater.
        #
        # This is also where we discard items from `_known_persisted_positions`
        # (to ensure the list doesn't infinitely grow).
        while self._known_persisted_positions:
            if self._known_persisted_positions[0] <= self._persisted_upto_position:
                heapq.heappop(self._known_persisted_positions)
            elif (
                self._known_persisted_positions[0] == self._persisted_upto_position + 1
            ):
                heapq.heappop(self._known_persisted_positions)
                self._persisted_upto_position += 1
            else:
                # There was a gap in seen positions, so there is nothing more to
                # do.
                break

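    # Worked example for the loop above (assumed values): with
    # `_persisted_upto_position = 4` and known positions [5, 6, 8], the loop
    # pops 5 and then 6 (each exactly one greater than the position as it
    # advances), leaving the position at 6; it then stops at 8 because 7 is
    # missing from the run.
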
    def _update_stream_positions_table_txn(self, txn: Cursor) -> None:
        """Update the `stream_positions` table with newly persisted position."""

        if not self._writers:
            return

        # We upsert the value, ensuring on conflict that we always increase the
        # value (or decrease if stream goes backwards).
        sql = """
            INSERT INTO stream_positions (stream_name, instance_name, stream_id)
            VALUES (?, ?, ?)
            ON CONFLICT (stream_name, instance_name)
            DO UPDATE SET
                stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
        """ % {
            "agg": "GREATEST" if self._positive else "LEAST",
        }

        pos = self.get_current_token_for_writer(self._instance_name)
        txn.execute(sql, (self._stream_name, self._instance_name, pos))


@attr.s(frozen=True, auto_attribs=True)
class _AsyncCtxManagerWrapper(Generic[T]):
    """Helper class to convert a plain context manager to an async one.

    This is mainly useful if you have a plain context manager but the interface
    requires an async one.
    """

    inner: ContextManager[T]

    async def __aenter__(self) -> T:
        return self.inner.__enter__()

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> Optional[bool]:
        return self.inner.__exit__(exc_type, exc, tb)

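# Usage sketch for `_AsyncCtxManagerWrapper` (illustrative): wrapping a plain
# `contextlib.contextmanager` so it can be used with `async with`, as
# `StreamIdGenerator.get_next` does above:
#
#     @contextmanager
#     def manager() -> Generator[int, None, None]:
#         yield 42
#
#     # async with _AsyncCtxManagerWrapper(manager()) as value: ...
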

@attr.s(slots=True)
class _MultiWriterCtxManager:
    """Async context manager returned by MultiWriterIdGenerator"""

    id_gen = attr.ib(type=MultiWriterIdGenerator)
    multiple_ids = attr.ib(type=Optional[int], default=None)
    stream_ids = attr.ib(type=List[int], factory=list)

    async def __aenter__(self) -> Union[int, List[int]]:
        # It's safe to run this in autocommit mode as fetching values from a
        # sequence ignores transaction semantics anyway.
        self.stream_ids = await self.id_gen._db.runInteraction(
            "_load_next_mult_id",
            self.id_gen._load_next_mult_id_txn,
            self.multiple_ids or 1,
            db_autocommit=True,
        )

        if self.multiple_ids is None:
            return self.stream_ids[0] * self.id_gen._return_factor
        else:
            return [i * self.id_gen._return_factor for i in self.stream_ids]

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> bool:
        for i in self.stream_ids:
            self.id_gen._mark_id_as_finished(i)

        if exc_type is not None:
            return False

        # Update the `stream_positions` table with newly updated stream
        # ID (unless self._writers is not set in which case we don't
        # bother, as nothing will read it).
        #
        # We only do this on the success path so that the persisted current
        # position points to a persisted row with the correct instance name.
        #
        # We do this in autocommit mode as a) the upsert works correctly outside
        # transactions and b) reduces the amount of time the rows are locked
        # for. If we don't do this then we'll often hit serialization errors due
        # to the fact we default to REPEATABLE READ isolation levels.
        if self.id_gen._writers:
            await self.id_gen._db.runInteraction(
                "MultiWriterIdGenerator._update_table",
                self.id_gen._update_stream_positions_table_txn,
                db_autocommit=True,
            )

        return False
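
# Usage sketch for `MultiWriterIdGenerator` (illustrative; `persist_event` and
# the table/sequence names here are assumptions, not taken from any real
# configuration):
#
#     id_gen = MultiWriterIdGenerator(
#         db_conn,
#         db,
#         stream_name="events",
#         instance_name="worker1",
#         tables=[("events", "instance_name", "stream_ordering")],
#         sequence_name="events_stream_seq",
#         writers=["worker1", "worker2"],
#     )
#
#     async with id_gen.get_next() as stream_id:
#         await persist_event(stream_id)
#
# Other writers learn of "worker1"'s new position via replication, which calls
# `id_gen.advance("worker1", stream_id)` on their side.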