# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2018-2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module is responsible for getting events from the DB for pagination
and event streaming.

The order it returns events in depends on whether we are streaming forwards or
are paginating backwards. We do this because we want to handle out-of-order
messages nicely, while still returning them in the correct order when we
paginate backwards.

This is implemented by keeping two ordering columns: stream_ordering and
topological_ordering. Stream ordering is basically insertion/received order
(except for events from backfill requests). The topological_ordering is a
weak ordering of events based on the pdu graph.

This means that we have to have two different types of tokens, depending on
what sort order was used:
    - stream tokens are of the form: "s%d", which maps directly to the column
    - topological tokens: "t%d-%d", where the integers map to the topological
      and stream ordering columns respectively.
"""

import logging
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple

import attr
from frozendict import frozendict

from twisted.internet import defer

from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
    make_in_list_sql_clause,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import PersistedEventPosition, RoomStreamToken
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


MAX_STREAM_SIZE = 1000


_STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"


# Used as return values for pagination APIs
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventDictReturn:
    event_id: str
    topological_ordering: Optional[int]
    stream_ordering: int


def generate_pagination_where_clause(
    direction: str,
    column_names: Tuple[str, str],
    from_token: Optional[Tuple[Optional[int], int]],
    to_token: Optional[Tuple[Optional[int], int]],
    engine: BaseDatabaseEngine,
) -> str:
    """Creates an SQL expression to bound the columns by the pagination
    tokens.

    For example, an SQL expression like:

        (6, 7) >= (topological_ordering, stream_ordering)
        AND (5, 3) < (topological_ordering, stream_ordering)

    would be generated for direction="b", from_token=(6, 7) and to_token=(5, 3).

    Note that tokens are considered to be after the row they are in, e.g. if
    a row A has a token T, then we consider A to be before T. This convention
    is important when figuring out inequalities for the generated SQL, and
    produces the following result:
        - If paginating forwards then we exclude any rows matching the from
          token, but include those that match the to token.
        - If paginating backwards then we include any rows matching the from
          token, but exclude those that match the to token.

    Args:
        direction: Whether we're paginating backwards ("b") or forwards ("f").
        column_names: The column names to bound. Must *not* be user defined as
            these get inserted directly into the SQL statement without escapes.
        from_token: The start point for the pagination. This is an exclusive
            minimum bound if direction is "f", and an inclusive maximum bound if
            direction is "b".
        to_token: The end point for the pagination. This is an inclusive
            maximum bound if direction is "f", and an exclusive minimum bound if
            direction is "b".
        engine: The database engine to generate the clauses for

    Returns:
        The sql expression
    """
    assert direction in ("b", "f")

    where_clause = []
    if from_token:
        where_clause.append(
            _make_generic_sql_bound(
                bound=">=" if direction == "b" else "<",
                column_names=column_names,
                values=from_token,
                engine=engine,
            )
        )

    if to_token:
        where_clause.append(
            _make_generic_sql_bound(
                bound="<" if direction == "b" else ">=",
                column_names=column_names,
                values=to_token,
                engine=engine,
            )
        )

    return " AND ".join(where_clause)
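
# Illustrative sketch (not executed): the docstring example above corresponds
# to a call such as
#
#   clause = generate_pagination_where_clause(
#       direction="b",
#       column_names=("topological_ordering", "stream_ordering"),
#       from_token=(6, 7),
#       to_token=(5, 3),
#       engine=engine,  # assumed here to be a PostgresEngine
#   )
#
# which, on Postgres, produces (shown wrapped for readability):
#
#   ((6,7) >= (topological_ordering,stream_ordering))
#   AND ((5,3) < (topological_ordering,stream_ordering))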


def _make_generic_sql_bound(
    bound: str,
    column_names: Tuple[str, str],
    values: Tuple[Optional[int], int],
    engine: BaseDatabaseEngine,
) -> str:
    """Create an SQL expression that bounds the given column names by the
    values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.

    Only works with two columns.

    Older versions of SQLite don't support that syntax so we have to expand it
    out manually.

    Args:
        bound: The comparison operator to use. One of ">", "<", ">=",
            "<=", where the values are on the left and columns on the right.
        column_names: The column names. Must *not* be user defined
            as these get inserted directly into the SQL statement without
            escapes.
        values: The values to bound the columns by. If
            the first value is None then only creates a bound on the second
            column.
        engine: The database engine to generate the SQL for

    Returns:
        The SQL statement
    """

    assert bound in (">", "<", ">=", "<=")

    name1, name2 = column_names
    val1, val2 = values

    if val1 is None:
        val2 = int(val2)
        return "(%d %s %s)" % (val2, bound, name2)

    val1 = int(val1)
    val2 = int(val2)

    if isinstance(engine, PostgresEngine):
        # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
        # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
        # use the latter form when running against postgres.
        return "((%d,%d) %s (%s,%s))" % (val1, val2, bound, name1, name2)

    # We want to generate queries of e.g. the form:
    #
    #   (val1 < name1 OR (val1 = name1 AND val2 <= name2))
    #
    # which is equivalent to (val1, val2) < (name1, name2)

    return """(
        {val1:d} {strict_bound} {name1}
        OR ({val1:d} = {name1} AND {val2:d} {bound} {name2})
    )""".format(
        name1=name1,
        val1=val1,
        name2=name2,
        val2=val2,
        strict_bound=bound[0],  # The first comparison must always be strict, i.e. "<" or ">"
        bound=bound,
    )
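
# Illustrative sketch (not executed): for bound="<=", column_names=
# ("topological_ordering", "stream_ordering") and values=(5, 3), the SQLite
# branch above expands to the equivalent of:
#
#   (
#       5 < topological_ordering
#       OR (5 = topological_ordering AND 3 <= stream_ordering)
#   )
#
# which matches the same rows as ``(5,3) <= (topological_ordering,
# stream_ordering)`` does on Postgres.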


def _filter_results(
    lower_token: Optional[RoomStreamToken],
    upper_token: Optional[RoomStreamToken],
    instance_name: str,
    topological_ordering: int,
    stream_ordering: int,
) -> bool:
    """Returns True if the event persisted by the given instance at the given
    topological/stream_ordering falls between the two tokens (taking a None
    token to mean unbounded).

    Used to filter results from fetching events in the DB against the given
    tokens. This is necessary to handle the case where the tokens include
    position maps, which we handle by fetching more than necessary from the DB
    and then filtering (rather than attempting to construct a complicated SQL
    query).
    """

    event_historical_tuple = (
        topological_ordering,
        stream_ordering,
    )

    if lower_token:
        if lower_token.topological is not None:
            # If these are historical tokens we compare the `(topological, stream)`
            # tuples.
            if event_historical_tuple <= lower_token.as_historical_tuple():
                return False

        else:
            # If these are live tokens we compare the stream ordering against the
            # writer's stream position.
            if stream_ordering <= lower_token.get_stream_pos_for_instance(
                instance_name
            ):
                return False

    if upper_token:
        if upper_token.topological is not None:
            if upper_token.as_historical_tuple() < event_historical_tuple:
                return False
        else:
            if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering:
                return False

    return True
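
# Illustrative sketch (not executed), with assumed token values: if a live
# `lower_token` reports a stream position of 10 for "worker1", then an event
# persisted by "worker1" at stream_ordering 10 is filtered out (10 <= 10)
# while one at stream_ordering 11 is kept:
#
#   _filter_results(lower_token, None, "worker1", 42, 10)  # -> False
#   _filter_results(lower_token, None, "worker1", 42, 11)  # -> True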


def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
    # NB: This may create SQL clauses that don't optimise well (and we don't
    # have indices on all possible clauses). E.g. it may create
    # "room_id == X AND room_id != X", which postgres doesn't optimise.

    if not event_filter:
        return "", []

    clauses = []
    args = []

    if event_filter.types:
        clauses.append(
            "(%s)" % " OR ".join("event.type = ?" for _ in event_filter.types)
        )
        args.extend(event_filter.types)

    for typ in event_filter.not_types:
        clauses.append("event.type != ?")
        args.append(typ)

    if event_filter.senders:
        clauses.append(
            "(%s)" % " OR ".join("event.sender = ?" for _ in event_filter.senders)
        )
        args.extend(event_filter.senders)

    for sender in event_filter.not_senders:
        clauses.append("event.sender != ?")
        args.append(sender)

    if event_filter.rooms:
        clauses.append(
            "(%s)" % " OR ".join("event.room_id = ?" for _ in event_filter.rooms)
        )
        args.extend(event_filter.rooms)

    for room_id in event_filter.not_rooms:
        clauses.append("event.room_id != ?")
        args.append(room_id)

    if event_filter.contains_url:
        clauses.append("event.contains_url = ?")
        args.append(event_filter.contains_url)

    # We're only applying the "labels" filter on the database query, because applying the
    # "not_labels" filter via a SQL query is non-trivial. Instead, we let
    # event_filter.check_fields apply it, which is not as efficient but makes the
    # implementation simpler.
    if event_filter.labels:
        clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
        args.extend(event_filter.labels)

    # Filter on relation_senders / relation types from the joined tables.
    if event_filter.relation_senders:
        clauses.append(
            "(%s)"
            % " OR ".join(
                "related_event.sender = ?" for _ in event_filter.relation_senders
            )
        )
        args.extend(event_filter.relation_senders)

    if event_filter.relation_types:
        clauses.append(
            "(%s)"
            % " OR ".join("relation_type = ?" for _ in event_filter.relation_types)
        )
        args.extend(event_filter.relation_types)

    return " AND ".join(clauses), args
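
# Illustrative sketch (not executed): for a filter with
# types=["m.room.message"] and not_senders=["@spam:example.com"] (assumed
# values), the function returns roughly:
#
#   clause == "(event.type = ?) AND event.sender != ?"
#   args == ["m.room.message", "@spam:example.com"]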


class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._instance_name = hs.get_instance_name()
        self._send_federation = hs.should_send_federation()
        self._federation_shard_config = hs.config.worker.federation_shard_config

        # If we're a process that sends federation we may need to reset the
        # `federation_stream_position` table to match the current sharding
        # config. We don't do this now as otherwise two processes could conflict
        # during startup which would cause one to die.
        self._need_to_reset_federation_stream_positions = self._send_federation

        events_max = self.get_room_max_stream_ordering()
        event_cache_prefill, min_event_val = self.db_pool.get_cache_dict(
            db_conn,
            "events",
            entity_column="room_id",
            stream_column="stream_ordering",
            max_value=events_max,
        )
        self._events_stream_cache = StreamChangeCache(
            "EventsRoomStreamChangeCache",
            min_event_val,
            prefilled_cache=event_cache_prefill,
        )
        self._membership_stream_cache = StreamChangeCache(
            "MembershipStreamChangeCache", events_max
        )

        self._stream_order_on_start = self.get_room_max_stream_ordering()

    def get_room_max_stream_ordering(self) -> int:
        """Get the stream_ordering of regular events that we have committed up to

        Returns the maximum stream id such that all stream ids less than or
        equal to it have been successfully persisted.
        """
        return self._stream_id_gen.get_current_token()

    def get_room_min_stream_ordering(self) -> int:
        """Get the stream_ordering of backfilled events that we have committed up to

        Backfilled events use *negative* stream orderings, so this returns the
        minimum negative stream id such that all stream ids greater than or
        equal to it have been successfully persisted.
        """
        return self._backfill_id_gen.get_current_token()

    def get_room_max_token(self) -> RoomStreamToken:
        """Get a `RoomStreamToken` that marks the current maximum persisted
        position of the events stream. Useful to get a token that represents
        "now".

        The token returned is a "live" token that may have an instance_map
        component.
        """

        min_pos = self._stream_id_gen.get_current_token()

        positions = {}
        if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
            # The `min_pos` is the minimum position that we know all instances
            # have finished persisting to, so we only care about instances whose
            # positions are ahead of that. (Instance positions can be behind the
            # min position as there are times we can work out that the minimum
            # position is ahead of the naive minimum across all current
            # positions. See MultiWriterIdGenerator for details)
            positions = {
                i: p
                for i, p in self._stream_id_gen.get_positions().items()
                if p > min_pos
            }

        return RoomStreamToken(None, min_pos, frozendict(positions))
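
    # Illustrative sketch (not executed), with assumed positions: if the
    # minimum persisted position is 1000 but "worker2" has persisted up to
    # 1005, the returned token is, conceptually:
    #
    #   RoomStreamToken(None, 1000, frozendict({"worker2": 1005}))
    #
    # Instances at or behind the minimum position are omitted from the
    # instance map.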

    async def get_room_events_stream_for_rooms(
        self,
        room_ids: Collection[str],
        from_key: RoomStreamToken,
        to_key: RoomStreamToken,
        limit: int = 0,
        order: str = "DESC",
    ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
        """Get new room events in stream ordering since `from_key`.

        Args:
            room_ids
            from_key: Token before which no events are returned
            to_key: Token after which no events are returned. (This
                is typically the current stream token)
            limit: Maximum number of events to return
            order: Either "DESC" or "ASC". Determines which events are
                returned when the result is limited. If "DESC" then the most
                recent `limit` events are returned, otherwise returns the
                oldest `limit` events.

        Returns:
            A map from room id to a tuple containing:
                - list of recent events in the room
                - stream ordering key for the start of the chunk of events returned.
        """
        room_ids = self._events_stream_cache.get_entities_changed(
            room_ids, from_key.stream
        )

        if not room_ids:
            return {}

        results = {}
        room_ids = list(room_ids)
        for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
            res = await make_deferred_yieldable(
                defer.gatherResults(
                    [
                        run_in_background(
                            self.get_room_events_stream_for_room,
                            room_id,
                            from_key,
                            to_key,
                            limit,
                            order=order,
                        )
                        for room_id in rm_ids
                    ],
                    consumeErrors=True,
                )
            )
            results.update(dict(zip(rm_ids, res)))

        return results

    def get_rooms_that_changed(
        self, room_ids: Collection[str], from_key: RoomStreamToken
    ) -> Set[str]:
        """Given a list of rooms and a token, return rooms where there may have
        been changes.
        """
        from_id = from_key.stream
        return {
            room_id
            for room_id in room_ids
            if self._events_stream_cache.has_entity_changed(room_id, from_id)
        }

    async def get_room_events_stream_for_room(
        self,
        room_id: str,
        from_key: RoomStreamToken,
        to_key: RoomStreamToken,
        limit: int = 0,
        order: str = "DESC",
    ) -> Tuple[List[EventBase], RoomStreamToken]:
        """Get new room events in stream ordering since `from_key`.

        Args:
            room_id
            from_key: Token before which no events are returned
            to_key: Token after which no events are returned. (This
                is typically the current stream token)
            limit: Maximum number of events to return
            order: Either "DESC" or "ASC". Determines which events are
                returned when the result is limited. If "DESC" then the most
                recent `limit` events are returned, otherwise returns the
                oldest `limit` events.

        Returns:
            The list of events (in ascending stream order) and the token from the start
            of the chunk of events returned.
        """
        if from_key == to_key:
            return [], from_key

        has_changed = self._events_stream_cache.has_entity_changed(
            room_id, from_key.stream
        )

        if not has_changed:
            return [], from_key

        def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
            # To handle tokens with a non-empty instance_map we fetch more
            # results than necessary and then filter down
            min_from_id = from_key.stream
            max_to_id = to_key.get_max_stream_pos()

            sql = """
                SELECT event_id, instance_name, topological_ordering, stream_ordering
                FROM events
                WHERE
                    room_id = ?
                    AND not outlier
                    AND stream_ordering > ? AND stream_ordering <= ?
                ORDER BY stream_ordering %s LIMIT ?
            """ % (
                order,
            )
            txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit))

            rows = [
                _EventDictReturn(event_id, None, stream_ordering)
                for event_id, instance_name, topological_ordering, stream_ordering in txn
                if _filter_results(
                    from_key,
                    to_key,
                    instance_name,
                    topological_ordering,
                    stream_ordering,
                )
            ][:limit]
            return rows

        rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)

        ret = await self.get_events_as_list(
            [r.event_id for r in rows], get_prev_content=True
        )

        self._set_before_and_after(ret, rows, topo_order=False)

        if order.lower() == "desc":
            ret.reverse()

        if rows:
            key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
        else:
            # Assume we didn't get anything because there was nothing to
            # get.
            key = from_key

        return ret, key

    async def get_membership_changes_for_user(
        self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
    ) -> List[EventBase]:
        """Fetch membership events for a given user.

        All such events whose stream ordering `s` lies in the range
        `from_key < s <= to_key` are returned. Events are ordered by ascending stream
        order.
        """
        # Start by ruling out cases where a DB query is not necessary.
        if from_key == to_key:
            return []

        if from_key:
            has_changed = self._membership_stream_cache.has_entity_changed(
                user_id, int(from_key.stream)
            )
            if not has_changed:
                return []

        def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
            # To handle tokens with a non-empty instance_map we fetch more
            # results than necessary and then filter down
            min_from_id = from_key.stream
            max_to_id = to_key.get_max_stream_pos()

            sql = """
                SELECT m.event_id, instance_name, topological_ordering, stream_ordering
                FROM events AS e, room_memberships AS m
                WHERE e.event_id = m.event_id
                    AND m.user_id = ?
                    AND e.stream_ordering > ? AND e.stream_ordering <= ?
                ORDER BY e.stream_ordering ASC
            """
            txn.execute(
                sql,
                (
                    user_id,
                    min_from_id,
                    max_to_id,
                ),
            )

            rows = [
                _EventDictReturn(event_id, None, stream_ordering)
                for event_id, instance_name, topological_ordering, stream_ordering in txn
                if _filter_results(
                    from_key,
                    to_key,
                    instance_name,
                    topological_ordering,
                    stream_ordering,
                )
            ]

            return rows

        rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f)

        ret = await self.get_events_as_list(
            [r.event_id for r in rows], get_prev_content=True
        )

        self._set_before_and_after(ret, rows, topo_order=False)

        return ret

    async def get_recent_events_for_room(
        self, room_id: str, limit: int, end_token: RoomStreamToken
    ) -> Tuple[List[EventBase], RoomStreamToken]:
        """Get the most recent events in the room in topological ordering.

        Args:
            room_id
            limit
            end_token: The stream token representing now.

        Returns:
            A list of events and a token pointing to the start of the returned
            events. The events returned are in ascending topological order.
        """

        rows, token = await self.get_recent_event_ids_for_room(
            room_id, limit, end_token
        )

        events = await self.get_events_as_list(
            [r.event_id for r in rows], get_prev_content=True
        )

        self._set_before_and_after(events, rows)

        return events, token

    async def get_recent_event_ids_for_room(
        self, room_id: str, limit: int, end_token: RoomStreamToken
    ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
        """Get the most recent events in the room in topological ordering.

        Args:
            room_id
            limit
            end_token: The stream token representing now.

        Returns:
            A list of _EventDictReturn and a token pointing to the start of the
            returned events. The events returned are in ascending order.
        """
        # Allow a zero limit here, and no-op.
        if limit == 0:
            return [], end_token

        rows, token = await self.db_pool.runInteraction(
            "get_recent_event_ids_for_room",
            self._paginate_room_events_txn,
            room_id,
            from_token=end_token,
            limit=limit,
        )

        # We want to return the results in ascending order.
        rows.reverse()

        return rows, token

    async def get_room_event_before_stream_ordering(
        self, room_id: str, stream_ordering: int
    ) -> Optional[Tuple[int, int, str]]:
        """Gets details of the most recent event in a room at or before a
        given stream ordering.

        Args:
            room_id:
            stream_ordering:

        Returns:
            A tuple of (stream ordering, topological ordering, event_id)
        """

        def _f(txn: LoggingTransaction) -> Optional[Tuple[int, int, str]]:
            sql = (
                "SELECT stream_ordering, topological_ordering, event_id"
                " FROM events"
                " WHERE room_id = ? AND stream_ordering <= ?"
                " AND NOT outlier"
                " ORDER BY stream_ordering DESC"
                " LIMIT 1"
            )
            txn.execute(sql, (room_id, stream_ordering))
            return txn.fetchone()

        return await self.db_pool.runInteraction(
            "get_room_event_before_stream_ordering", _f
        )

    async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
        """Returns the current token for rooms stream.

        By default, it returns the current global stream token. Specifying a
        `room_id` causes it to return the current room specific topological
        token.
        """
        token = self.get_room_max_stream_ordering()
        if room_id is None:
            return "s%d" % (token,)
        else:
            topo = await self.db_pool.runInteraction(
                "_get_max_topological_txn", self._get_max_topological_txn, room_id
            )
            return "t%d-%d" % (topo, token)
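
    # Illustrative sketch (not executed): with an assumed max stream ordering
    # of 3452, the global form is "s3452"; for a room whose max topological
    # ordering is 4234, the room-specific form is "t4234-3452".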

    def get_stream_id_for_event_txn(
        self,
        txn: LoggingTransaction,
        event_id: str,
        allow_none: bool = False,
    ) -> int:
        """Get the stream_ordering for a given event id."""
        return self.db_pool.simple_select_one_onecol_txn(
            txn=txn,
            table="events",
            keyvalues={"event_id": event_id},
            retcol="stream_ordering",
            allow_none=allow_none,
        )

    async def get_position_for_event(self, event_id: str) -> PersistedEventPosition:
        """Get the persisted position for an event"""
        row = await self.db_pool.simple_select_one(
            table="events",
            keyvalues={"event_id": event_id},
            retcols=("stream_ordering", "instance_name"),
            desc="get_position_for_event",
        )

        return PersistedEventPosition(
            row["instance_name"] or "master", row["stream_ordering"]
        )

    async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
        """Get the topological token for an event.

        Args:
            event_id: The id of the event to look up a stream token for.
        Raises:
            StoreError if the event wasn't in the database.
        Returns:
            A `RoomStreamToken` topological token.
        """
        row = await self.db_pool.simple_select_one(
            table="events",
            keyvalues={"event_id": event_id},
            retcols=("stream_ordering", "topological_ordering"),
            desc="get_topological_token_for_event",
        )
        return RoomStreamToken(row["topological_ordering"], row["stream_ordering"])

    async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
        """Gets the topological token in a room after or at the given stream
        ordering.

        Args:
            room_id
            stream_key
        """
        sql = (
            "SELECT coalesce(MIN(topological_ordering), 0) FROM events"
            " WHERE room_id = ? AND stream_ordering >= ?"
        )
        row = await self.db_pool.execute(
            "get_current_topological_token", None, sql, room_id, stream_key
        )
        return row[0][0] if row else 0

    def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
        txn.execute(
            "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
            (room_id,),
        )

        rows = txn.fetchall()
        return rows[0][0] if rows else 0

    @staticmethod
    def _set_before_and_after(
        events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
    ) -> None:
        """Inserts ordering information to events' internal metadata from
        the DB rows.

        Args:
            events
            rows
            topo_order: Whether the events were ordered topologically or by stream
                ordering. If true then all rows should have a non null
                topological_ordering.
        """
        for event, row in zip(events, rows):
            stream = row.stream_ordering
            if topo_order and row.topological_ordering:
                topo: Optional[int] = row.topological_ordering
            else:
                topo = None
            internal = event.internal_metadata
            internal.before = RoomStreamToken(topo, stream - 1)
            internal.after = RoomStreamToken(topo, stream)
            internal.order = (int(topo) if topo else 0, int(stream))

    async def get_events_around(
        self,
        room_id: str,
        event_id: str,
        before_limit: int,
        after_limit: int,
        event_filter: Optional[Filter] = None,
    ) -> dict:
        """Retrieve events and pagination tokens around a given event in a
        room.
        """

        results = await self.db_pool.runInteraction(
            "get_events_around",
            self._get_events_around_txn,
            room_id,
            event_id,
            before_limit,
            after_limit,
            event_filter,
        )

        events_before = await self.get_events_as_list(
            list(results["before"]["event_ids"]), get_prev_content=True
        )

        events_after = await self.get_events_as_list(
            list(results["after"]["event_ids"]), get_prev_content=True
        )

        return {
            "events_before": events_before,
            "events_after": events_after,
            "start": results["before"]["token"],
            "end": results["after"]["token"],
        }

    def _get_events_around_txn(
        self,
        txn: LoggingTransaction,
        room_id: str,
        event_id: str,
        before_limit: int,
        after_limit: int,
        event_filter: Optional[Filter],
    ) -> dict:
        """Retrieves event_ids and pagination tokens around a given event in a
        room.

        Args:
            room_id
            event_id
            before_limit
            after_limit
            event_filter

        Returns:
            dict
        """

        results = self.db_pool.simple_select_one_txn(
            txn,
            "events",
            keyvalues={"event_id": event_id, "room_id": room_id},
            retcols=["stream_ordering", "topological_ordering"],
        )

        # This cannot happen as `allow_none=False`.
        assert results is not None

        # Paginating backwards includes the event at the token, but paginating
        # forward doesn't.
        before_token = RoomStreamToken(
            results["topological_ordering"] - 1, results["stream_ordering"]
        )

        after_token = RoomStreamToken(
            results["topological_ordering"], results["stream_ordering"]
        )

        rows, start_token = self._paginate_room_events_txn(
            txn,
            room_id,
            before_token,
            direction="b",
            limit=before_limit,
            event_filter=event_filter,
        )
        events_before = [r.event_id for r in rows]

        rows, end_token = self._paginate_room_events_txn(
            txn,
            room_id,
            after_token,
            direction="f",
            limit=after_limit,
            event_filter=event_filter,
        )
        events_after = [r.event_id for r in rows]

        return {
            "before": {"event_ids": events_before, "token": start_token},
            "after": {"event_ids": events_after, "token": end_token},
        }
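
    # Illustrative sketch (not executed): the dict returned above has the shape
    #
    #   {
    #       "before": {"event_ids": [...], "token": <RoomStreamToken>},
    #       "after": {"event_ids": [...], "token": <RoomStreamToken>},
    #   }
    #
    # where "before" holds events paginated backwards from just before the
    # target event, and "after" holds events paginated forwards from it.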

    async def get_all_new_events_stream(
        self, from_id: int, current_id: int, limit: int
    ) -> Tuple[int, List[EventBase]]:
        """Get all new events

        Returns all events with from_id < stream_ordering <= current_id.

        Args:
            from_id:  the stream_ordering of the last event we processed
            current_id:  the stream_ordering of the most recently processed event
            limit: the maximum number of events to return

        Returns:
            A tuple of (next_id, events), where `next_id` is the next value to
            pass as `from_id` (it will either be the stream_ordering of the
            last returned event, or, if fewer than `limit` events were found,
            the `current_id`).
        """

        def get_all_new_events_stream_txn(
            txn: LoggingTransaction,
        ) -> Tuple[int, List[str]]:
            sql = (
                "SELECT e.stream_ordering, e.event_id"
                " FROM events AS e"
                " WHERE"
                " ? < e.stream_ordering AND e.stream_ordering <= ?"
                " ORDER BY e.stream_ordering ASC"
                " LIMIT ?"
            )

            txn.execute(sql, (from_id, current_id, limit))
            rows = txn.fetchall()

            upper_bound = current_id
            if len(rows) == limit:
                upper_bound = rows[-1][0]

            return upper_bound, [row[1] for row in rows]

        upper_bound, event_ids = await self.db_pool.runInteraction(
            "get_all_new_events_stream", get_all_new_events_stream_txn
        )

        events = await self.get_events_as_list(event_ids)

        return upper_bound, events

    async def get_federation_out_pos(self, typ: str) -> int:
        if self._need_to_reset_federation_stream_positions:
            await self.db_pool.runInteraction(
                "_reset_federation_positions_txn", self._reset_federation_positions_txn
            )
            self._need_to_reset_federation_stream_positions = False

        return await self.db_pool.simple_select_one_onecol(
            table="federation_stream_position",
            retcol="stream_id",
            keyvalues={"type": typ, "instance_name": self._instance_name},
            desc="get_federation_out_pos",
        )

    async def update_federation_out_pos(self, typ: str, stream_id: int) -> None:
        if self._need_to_reset_federation_stream_positions:
            await self.db_pool.runInteraction(
                "_reset_federation_positions_txn", self._reset_federation_positions_txn
            )
            self._need_to_reset_federation_stream_positions = False

        await self.db_pool.simple_update_one(
            table="federation_stream_position",
            keyvalues={"type": typ, "instance_name": self._instance_name},
            updatevalues={"stream_id": stream_id},
            desc="update_federation_out_pos",
        )

    def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
        """Fiddles with the `federation_stream_position` table to make it match
        the configured federation sender instances during start up.
        """

        # The federation sender instances may have changed, so we need to
        # massage the `federation_stream_position` table to have a row per type
        # per instance sending federation. If there is a mismatch we update the
        # table with the correct rows using the *minimum* stream ID seen. This
        # may result in resending of events/EDUs to remote servers, but that is
        # preferable to dropping them.

        if not self._send_federation:
            return

        # Pull out the configured instances. If we don't have a shard config then
        # we assume that we're the only instance sending.
        configured_instances = self._federation_shard_config.instances
        if not configured_instances:
            configured_instances = [self._instance_name]
        elif self._instance_name not in configured_instances:
            return

        instances_in_table = self.db_pool.simple_select_onecol_txn(
            txn,
            table="federation_stream_position",
            keyvalues={},
            retcol="instance_name",
        )

        if set(instances_in_table) == set(configured_instances):
            # Nothing to do
            return

        sql = """
            SELECT type, MIN(stream_id) FROM federation_stream_position
            GROUP BY type
        """
        txn.execute(sql)
        min_positions = {typ: pos for typ, pos in txn}  # Map from type -> min position

        # Ensure we do actually have some values here
        assert set(min_positions) == {"federation", "events"}

        sql = """
            DELETE FROM federation_stream_position
            WHERE NOT (%s)
        """
        clause, args = make_in_list_sql_clause(
            txn.database_engine, "instance_name", configured_instances
        )
        txn.execute(sql % (clause,), args)

        for typ, stream_id in min_positions.items():
            self.db_pool.simple_upsert_txn(
                txn,
                table="federation_stream_position",
                keyvalues={"type": typ, "instance_name": self._instance_name},
                values={"stream_id": stream_id},
            )
    def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:
        return self._events_stream_cache.has_entity_changed(room_id, stream_id)

    def _paginate_room_events_txn(
        self,
        txn: LoggingTransaction,
        room_id: str,
        from_token: RoomStreamToken,
        to_token: Optional[RoomStreamToken] = None,
        direction: str = "b",
        limit: int = -1,
        event_filter: Optional[Filter] = None,
    ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
        """Returns list of events before or after a given token.

        Args:
            txn
            room_id
            from_token: The token used to stream from
            to_token: A token which, if given, limits the results to only those
                before it
            direction: Either 'b' (backwards) or 'f' (forwards), indicating
                which direction we are paginating in from `from_token`.
            limit: The maximum number of events to return.
            event_filter: If provided filters the events to
                those that match the filter.

        Returns:
            A list of _EventDictReturn and a token that points to the end of the
            result set. If no events are returned then the end of the stream has
            been reached (i.e. there are no events between `from_token` and
            `to_token`), or `limit` is zero.
        """

        assert int(limit) >= 0

        # Tokens really represent positions between elements, but we use
        # the convention of pointing to the event before the gap. Hence
        # we have a bit of asymmetry when it comes to equalities.
        args = [False, room_id]
        if direction == "b":
            order = "DESC"
        else:
            order = "ASC"

        # The bounds for the stream tokens are complicated by the fact
        # that we need to handle the instance_map part of the tokens. We do this
        # by fetching all events between the min stream token and the maximum
        # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
        # then filtering the results.
        if from_token.topological is not None:
            from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple()
        elif direction == "b":
            from_bound = (
                None,
                from_token.get_max_stream_pos(),
            )
        else:
            from_bound = (
                None,
                from_token.stream,
            )

        to_bound: Optional[Tuple[Optional[int], int]] = None
        if to_token:
            if to_token.topological is not None:
                to_bound = to_token.as_historical_tuple()
            elif direction == "b":
                to_bound = (
                    None,
                    to_token.stream,
                )
            else:
                to_bound = (
                    None,
                    to_token.get_max_stream_pos(),
                )

        bounds = generate_pagination_where_clause(
            direction=direction,
            column_names=("event.topological_ordering", "event.stream_ordering"),
            from_token=from_bound,
            to_token=to_bound,
            engine=self.database_engine,
        )

        filter_clause, filter_args = filter_to_clause(event_filter)

        if filter_clause:
            bounds += " AND " + filter_clause
            args.extend(filter_args)

        # We fetch more events as we'll filter the result set
        args.append(int(limit) * 2)

        select_keywords = "SELECT"
        join_clause = ""
        # Using DISTINCT in this SELECT query is quite expensive, because it
        # requires the engine to sort on the entire (not limited) result set,
        # i.e. the entire events table. Only use it in scenarios that could result
        # in the same event ID occurring multiple times in the results.
        needs_distinct = False
        if event_filter and event_filter.labels:
            # If we're not filtering on a label, then joining on event_labels will
            # return as many rows for a single event as the number of labels it has. To
            # avoid this, only join if we're filtering on at least one label.
            join_clause += """
                LEFT JOIN event_labels
                USING (event_id, room_id, topological_ordering)
            """
            if len(event_filter.labels) > 1:
                # Multiple labels could cause the same event to appear multiple times.
                needs_distinct = True

        # If there is a filter on relation_senders and relation_types join to the
        # relations table.
        if event_filter and (
            event_filter.relation_senders or event_filter.relation_types
        ):
            # Filtering by relations could cause the same event to appear multiple
            # times (since there's no limit on the number of relations to an event).
            needs_distinct = True
            join_clause += """
                LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id)
            """
            if event_filter.relation_senders:
                join_clause += """
                    LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id)
                """

        if needs_distinct:
            select_keywords += " DISTINCT"

        sql = """
            %(select_keywords)s
                event.event_id, event.instance_name,
                event.topological_ordering, event.stream_ordering
            FROM events AS event
            %(join_clause)s
            WHERE event.outlier = ? AND event.room_id = ? AND %(bounds)s
            ORDER BY event.topological_ordering %(order)s,
            event.stream_ordering %(order)s LIMIT ?
        """ % {
            "select_keywords": select_keywords,
            "join_clause": join_clause,
            "bounds": bounds,
            "order": order,
        }

        txn.execute(sql, args)

        # Filter the result set.
        rows = [
            _EventDictReturn(event_id, topological_ordering, stream_ordering)
            for event_id, instance_name, topological_ordering, stream_ordering in txn
            if _filter_results(
                lower_token=to_token if direction == "b" else from_token,
                upper_token=from_token if direction == "b" else to_token,
                instance_name=instance_name,
                topological_ordering=topological_ordering,
                stream_ordering=stream_ordering,
            )
        ][:limit]

        if rows:
            topo = rows[-1].topological_ordering
            toke = rows[-1].stream_ordering
            if direction == "b":
                # Tokens are positions between events.
                # This token points *after* the last event in the chunk.
                # We need it to point to the event before it in the chunk
                # when we are going backwards so we subtract one from the
                # stream part.
                toke -= 1
            next_token = RoomStreamToken(topo, toke)
        else:
            # TODO (erikj): We should work out what to do here instead.
            next_token = to_token if to_token else from_token

        return rows, next_token
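
    # Illustrative note: when paginating backwards from a "live" token the
    # upper stream bound above is `from_token.get_max_stream_pos()`, so rows
    # persisted by a lagging writer are still fetched and then trimmed by
    # `_filter_results`. For a historical token such as "t4234-3452" the bound
    # is the tuple (4234, 3452) instead. (Token values here are assumptions
    # for the example.)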

    async def paginate_room_events(
        self,
        room_id: str,
        from_key: RoomStreamToken,
        to_key: Optional[RoomStreamToken] = None,
        direction: str = "b",
        limit: int = -1,
        event_filter: Optional[Filter] = None,
    ) -> Tuple[List[EventBase], RoomStreamToken]:
        """Returns list of events before or after a given token.

        Args:
            room_id
            from_key: The token used to stream from
            to_key: A token which, if given, limits the results to only those
                before it
            direction: Either 'b' (backwards) or 'f' (forwards), indicating
                which direction we are paginating in from `from_key`.
            limit: The maximum number of events to return.
            event_filter: If provided filters the events to those that match the filter.

        Returns:
            The results as a list of events and a token that points to the end
            of the result set. If no events are returned then the end of the
            stream has been reached (i.e. there are no events between `from_key`
            and `to_key`).
        """

        rows, token = await self.db_pool.runInteraction(
            "paginate_room_events",
            self._paginate_room_events_txn,
            room_id,
            from_key,
            to_key,
            direction,
            limit,
            event_filter,
        )

        events = await self.get_events_as_list(
            [r.event_id for r in rows], get_prev_content=True
        )

        self._set_before_and_after(events, rows)

        return events, token
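
    # Illustrative usage sketch (not executed); `store`, `room_id` and
    # `from_key` are assumed to be in scope:
    #
    #   events, next_token = await store.paginate_room_events(
    #       room_id, from_key, direction="b", limit=10
    #   )
    #   # `events` are the next 10 events going backwards; pass `next_token`
    #   # back in as `from_key` to fetch the following page.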

    @cached()
    async def get_id_for_instance(self, instance_name: str) -> int:
        """Get a unique, immutable ID that corresponds to the given Synapse worker instance."""

        def _get_id_for_instance_txn(txn: LoggingTransaction) -> int:
            instance_id = self.db_pool.simple_select_one_onecol_txn(
                txn,
                table="instance_map",
                keyvalues={"instance_name": instance_name},
                retcol="instance_id",
                allow_none=True,
            )
            if instance_id is not None:
                return instance_id

            # If we don't have an entry upsert one.
            #
            # We could do this before the first check, and rely on the cache for
            # efficiency, but each UPSERT causes the next ID to increment which
            # can quickly bloat the size of the generated IDs for new instances.
            self.db_pool.simple_upsert_txn(
                txn,
                table="instance_map",
                keyvalues={"instance_name": instance_name},
                values={},
            )

            return self.db_pool.simple_select_one_onecol_txn(
                txn,
                table="instance_map",
                keyvalues={"instance_name": instance_name},
                retcol="instance_id",
            )

        return await self.db_pool.runInteraction(
            "get_id_for_instance", _get_id_for_instance_txn
        )

    @cached()
    async def get_name_from_instance_id(self, instance_id: int) -> str:
        """Get the instance name from an ID previously returned by
        `get_id_for_instance`.
        """

        return await self.db_pool.simple_select_one_onecol(
            table="instance_map",
            keyvalues={"instance_id": instance_id},
            retcol="instance_name",
            desc="get_name_from_instance_id",
        )