# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Collection,
    Dict,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
)

from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class ReceiptsWorkerStore(SQLBaseStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        self._instance_name = hs.get_instance_name()

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_receipts = (
                self._instance_name in hs.config.worker.writers.receipts
            )

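            # On Postgres multiple workers may write receipts, so stream IDs
            # are drawn from a shared sequence and tracked per writer instance.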
            self._receipts_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="receipts",
                instance_name=self._instance_name,
                tables=[("receipts_linearized", "instance_name", "stream_id")],
                sequence_name="receipts_sequence",
                writers=hs.config.worker.writers.receipts,
            )
        else:
            self._can_write_to_receipts = True

            # We shouldn't be running in worker mode with SQLite, but it's useful
            # to support it for unit tests.
            #
            # If this process is the writer then we need to use
            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker`, which gets
            # updated over replication. (Multiple writers are not supported for
            # SQLite.)
            if hs.get_instance_name() in hs.config.worker.writers.receipts:
                self._receipts_id_gen = StreamIdGenerator(
                    db_conn, "receipts_linearized", "stream_id"
                )
            else:
                self._receipts_id_gen = SlavedIdTracker(
                    db_conn, "receipts_linearized", "stream_id"
                )

        super().__init__(database, db_conn, hs)

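        # Tracks which rooms have had receipts added at or after a given stream
        # position, so readers can skip rooms with no new receipts.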
        self._receipts_stream_cache = StreamChangeCache(
            "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
        )

    def get_max_receipt_stream_id(self) -> int:
        """Get the current max stream ID for receipts stream"""
        return self._receipts_id_gen.get_current_token()

    @cached()
    async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]:
        receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ)
        return {r["user_id"] for r in receipts}

    @cached(num_args=2)
    async def get_receipts_for_room(
        self, room_id: str, receipt_type: str
    ) -> List[Dict[str, Any]]:
        return await self.db_pool.simple_select_list(
            table="receipts_linearized",
            keyvalues={"room_id": room_id, "receipt_type": receipt_type},
            retcols=("user_id", "event_id"),
            desc="get_receipts_for_room",
        )

    @cached(num_args=3)
    async def get_last_receipt_event_id_for_user(
        self, user_id: str, room_id: str, receipt_type: str
    ) -> Optional[str]:
        return await self.db_pool.simple_select_one_onecol(
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
            retcol="event_id",
            desc="get_own_receipt_for_user",
            allow_none=True,
        )

    @cached(num_args=2)
    async def get_receipts_for_user(
        self, user_id: str, receipt_type: str
    ) -> Dict[str, str]:
        rows = await self.db_pool.simple_select_list(
            table="receipts_linearized",
            keyvalues={"user_id": user_id, "receipt_type": receipt_type},
            retcols=("room_id", "event_id"),
            desc="get_receipts_for_user",
        )

        return {row["room_id"]: row["event_id"] for row in rows}

    async def get_receipts_for_user_with_orderings(
        self, user_id: str, receipt_type: str
    ) -> JsonDict:
        def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
            sql = (
                "SELECT rl.room_id, rl.event_id,"
                " e.topological_ordering, e.stream_ordering"
                " FROM receipts_linearized AS rl"
                " INNER JOIN events AS e USING (room_id, event_id)"
                " WHERE rl.room_id = e.room_id"
                " AND rl.event_id = e.event_id"
                " AND user_id = ?"
                " AND receipt_type = ?"
            )
            txn.execute(sql, (user_id, receipt_type))
            return txn.fetchall()

        rows = await self.db_pool.runInteraction(
            "get_receipts_for_user_with_orderings", f
        )
        return {
            row[0]: {
                "event_id": row[1],
                "topological_ordering": row[2],
                "stream_ordering": row[3],
            }
            for row in rows
        }

    async def get_linearized_receipts_for_rooms(
        self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
    ) -> List[dict]:
        """Get receipts for multiple rooms for sending to clients.

        Args:
            room_ids: The room IDs to fetch receipts for.
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A list of receipts.
        """
        room_ids = set(room_ids)

        if from_key is not None:
            # Only ask the database about rooms where there have been new
            # receipts added since `from_key`
            room_ids = self._receipts_stream_cache.get_entities_changed(
                room_ids, from_key
            )

        results = await self._get_linearized_receipts_for_rooms(
            room_ids, to_key, from_key=from_key
        )

        return [ev for res in results.values() for ev in res]

    async def get_linearized_receipts_for_room(
        self, room_id: str, to_key: int, from_key: Optional[int] = None
    ) -> List[dict]:
        """Get receipts for a single room for sending to clients.

        Args:
            room_id: The room ID to fetch receipts for.
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A list of receipts.
        """
        if from_key is not None:
            # Check the cache first to see if any new receipts have been added
            # since `from_key`. If not, we can no-op.
            if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
                return []

        return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)

    @cached(num_args=3, tree=True)
    async def _get_linearized_receipts_for_room(
        self, room_id: str, to_key: int, from_key: Optional[int] = None
    ) -> List[JsonDict]:
        """See get_linearized_receipts_for_room"""

        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
            if from_key:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id > ? AND stream_id <= ?"
                )

                txn.execute(sql, (room_id, from_key, to_key))
            else:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id <= ?"
                )

                txn.execute(sql, (room_id, to_key))

            rows = self.db_pool.cursor_to_dict(txn)

            return rows

        rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)

        if not rows:
            return []

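        # Build the content of a single m.receipt event for the room, of the
        # form {event_id: {receipt_type: {user_id: receipt_data}}}.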
        content = {}
        for row in rows:
            content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
                row["user_id"]
            ] = db_to_json(row["data"])

        return [{"type": "m.receipt", "room_id": room_id, "content": content}]

    @cachedList(
        cached_method_name="_get_linearized_receipts_for_room",
        list_name="room_ids",
        num_args=3,
    )
    async def _get_linearized_receipts_for_rooms(
        self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
    ) -> Dict[str, List[JsonDict]]:
        if not room_ids:
            return {}

        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
            if from_key:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id > ? AND stream_id <= ? AND
                """
                clause, args = make_in_list_sql_clause(
                    self.database_engine, "room_id", room_ids
                )

                txn.execute(sql + clause, [from_key, to_key] + list(args))
            else:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id <= ? AND
                """

                clause, args = make_in_list_sql_clause(
                    self.database_engine, "room_id", room_ids
                )

                txn.execute(sql + clause, [to_key] + list(args))

            return self.db_pool.cursor_to_dict(txn)

        txn_results = await self.db_pool.runInteraction(
            "_get_linearized_receipts_for_rooms", f
        )

        results = {}
        for row in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = results.setdefault(
                row["room_id"],
                {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
            )

            # The content is of the form:
            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
            event_entry = room_event["content"].setdefault(row["event_id"], {})
            receipt_type = event_entry.setdefault(row["receipt_type"], {})

            receipt_type[row["user_id"]] = db_to_json(row["data"])

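        # @cachedList expects an entry for every requested room_id, so rooms
        # with no matching receipts map to an empty list.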
        results = {
            room_id: [results[room_id]] if room_id in results else []
            for room_id in room_ids
        }
        return results

    @cached(
        num_args=2,
    )
    async def get_linearized_receipts_for_all_rooms(
        self, to_key: int, from_key: Optional[int] = None
    ) -> Dict[str, JsonDict]:
        """Get receipts for all rooms between two stream_ids, up
        to a limit of the latest 100 read receipts.

        Args:
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A dictionary mapping room ID to the m.receipt event for that room.
        """

        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
            if from_key:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id > ? AND stream_id <= ?
                    ORDER BY stream_id DESC
                    LIMIT 100
                """
                txn.execute(sql, [from_key, to_key])
            else:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id <= ?
                    ORDER BY stream_id DESC
                    LIMIT 100
                """

                txn.execute(sql, [to_key])

            return self.db_pool.cursor_to_dict(txn)

        txn_results = await self.db_pool.runInteraction(
            "get_linearized_receipts_for_all_rooms", f
        )

        results = {}
        for row in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = results.setdefault(
                row["room_id"],
                {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
            )

            # The content is of the form:
            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
            event_entry = room_event["content"].setdefault(row["event_id"], {})
            receipt_type = event_entry.setdefault(row["receipt_type"], {})

            receipt_type[row["user_id"]] = db_to_json(row["data"])

        return results

    async def get_users_sent_receipts_between(
        self, last_id: int, current_id: int
    ) -> List[str]:
        """Get all users who sent receipts between `last_id` exclusive and
        `current_id` inclusive.

        Returns:
            The list of users.
        """

        if last_id == current_id:
            return []

        def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
            sql = """
                SELECT DISTINCT user_id FROM receipts_linearized
                WHERE ? < stream_id AND stream_id <= ?
            """
            txn.execute(sql, (last_id, current_id))

            return [r[0] for r in txn]

        return await self.db_pool.runInteraction(
            "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
        )

    async def get_all_updated_receipts(
        self, instance_name: str, last_id: int, current_id: int, limit: int
    ) -> Tuple[List[Tuple[int, list]], int, bool]:
        """Get updates for receipts replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exist
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_updated_receipts_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[Tuple[int, list]], int, bool]:
            sql = """
                SELECT stream_id, room_id, receipt_type, user_id, event_id, data
                FROM receipts_linearized
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
                LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, limit))

            updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]

            limited = False
            upper_bound = current_id

            if len(updates) == limit:
                limited = True
                upper_bound = updates[-1][0]

            return updates, upper_bound, limited

        return await self.db_pool.runInteraction(
            "get_all_updated_receipts", get_all_updated_receipts_txn
        )

    def _invalidate_get_users_with_receipts_in_room(
        self, room_id: str, receipt_type: str, user_id: str
    ) -> None:
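        # Only read receipts feed get_users_with_read_receipts_in_room, so
        # other receipt types can be ignored here.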
        if receipt_type != ReceiptTypes.READ:
            return

        res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
            room_id, None, update_metrics=False
        )

        if res and user_id in res:
            # We'd only be adding to the set, so no point invalidating if the
            # user is already there
            return

        self.get_users_with_read_receipts_in_room.invalidate((room_id,))

    def invalidate_caches_for_receipt(
        self, room_id: str, receipt_type: str, user_id: str
    ) -> None:
        self.get_receipts_for_user.invalidate((user_id, receipt_type))
        self._get_linearized_receipts_for_room.invalidate((room_id,))
        self.get_last_receipt_event_id_for_user.invalidate(
            (user_id, room_id, receipt_type)
        )
        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
        self.get_receipts_for_room.invalidate((room_id, receipt_type))

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == ReceiptsStream.NAME:
            self._receipts_id_gen.advance(instance_name, token)
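            # Invalidate this worker's caches for each receipt received over
            # replication from the writer.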
            for row in rows:
                self.invalidate_caches_for_receipt(
                    row.room_id, row.receipt_type, row.user_id
                )
                self._receipts_stream_cache.entity_has_changed(row.room_id, token)

        return super().process_replication_rows(stream_name, instance_name, token, rows)

    def insert_linearized_receipt_txn(
        self,
        txn: LoggingTransaction,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_id: str,
        data: JsonDict,
        stream_id: int,
    ) -> Optional[int]:
        """Inserts a read-receipt into the database if it's newer than the current RR

        Returns:
            None if the RR is older than the current RR
            otherwise, the rx timestamp of the event that the RR corresponds to
                (or 0 if the event is unknown)
        """
        assert self._can_write_to_receipts

        res = self.db_pool.simple_select_one_txn(
            txn,
            table="events",
            retcols=["stream_ordering", "received_ts"],
            keyvalues={"event_id": event_id},
            allow_none=True,
        )

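        # The event may not be known locally (e.g. a receipt for an event we
        # have not yet received), in which case we can't compare orderings.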
        stream_ordering = int(res["stream_ordering"]) if res else None
        rx_ts = res["received_ts"] if res else 0

        # We don't want to clobber receipts for more recent events, so we
        # have to compare orderings of existing receipts
        if stream_ordering is not None:
            sql = (
                "SELECT stream_ordering, event_id FROM events"
                " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
                " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
            )
            txn.execute(sql, (room_id, receipt_type, user_id))

            for so, eid in txn:
                if int(so) >= stream_ordering:
                    logger.debug(
                        "Ignoring new receipt for %s in favour of existing "
                        "one for later event %s",
                        event_id,
                        eid,
                    )
                    return None

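        # Defer cache invalidation until the transaction commits, so readers
        # never see caches invalidated for data that was rolled back.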
        txn.call_after(
            self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
        )

        txn.call_after(
            self._receipts_stream_cache.entity_has_changed, room_id, stream_id
        )

        self.db_pool.simple_upsert_txn(
            txn,
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
            values={
                "stream_id": stream_id,
                "event_id": event_id,
                "data": json_encoder.encode(data),
            },
            # receipts_linearized has a unique constraint on
            # (user_id, room_id, receipt_type), so no need to lock
            lock=False,
        )

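        # A read receipt means the user has read up to this event, so any
        # push actions for earlier events in the room can be cleared.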
        if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
            self._remove_old_push_actions_before_txn(
                txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
            )

        return rx_ts

    async def insert_receipt(
        self,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_ids: List[str],
        data: dict,
    ) -> Optional[Tuple[int, int]]:
        """Insert a receipt, either from local client or remote server.

        Automatically does conversion between linearized and graph
        representations.
        """
        assert self._can_write_to_receipts

        if not event_ids:
            return None

        if len(event_ids) == 1:
            linearized_event_id = event_ids[0]
        else:
            # We need to map the points in the event graph to linearized form.
            # TODO: Make this better.
            def graph_to_linear(txn: LoggingTransaction) -> str:
                clause, args = make_in_list_sql_clause(
                    self.database_engine, "event_id", event_ids
                )

                sql = """
                    SELECT event_id FROM events WHERE room_id = ? AND stream_ordering IN (
                        SELECT max(stream_ordering) FROM events WHERE %s
                    )
                """ % (
                    clause,
                )

                txn.execute(sql, [room_id] + list(args))
                rows = txn.fetchall()
                if rows:
                    return rows[0][0]
                else:
                    raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))

            linearized_event_id = await self.db_pool.runInteraction(
                "insert_receipt_conv", graph_to_linear
            )

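        # Allocate the next receipts stream ID; the context manager marks it as
        # persisted when the block exits so the current stream token can advance.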
        async with self._receipts_id_gen.get_next() as stream_id:
            event_ts = await self.db_pool.runInteraction(
                "insert_linearized_receipt",
                self.insert_linearized_receipt_txn,
                room_id,
                receipt_type,
                user_id,
                linearized_event_id,
                data,
                stream_id=stream_id,
            )

        if event_ts is None:
            return None

        now = self._clock.time_msec()
        logger.debug(
            "RR for event %s in %s (%i ms old)",
            linearized_event_id,
            room_id,
            now - event_ts,
        )

        await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)

        max_persisted_id = self._receipts_id_gen.get_current_token()

        return stream_id, max_persisted_id

    async def insert_graph_receipt(
        self,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_ids: List[str],
        data: JsonDict,
    ) -> None:
        assert self._can_write_to_receipts

        await self.db_pool.runInteraction(
            "insert_graph_receipt",
            self.insert_graph_receipt_txn,
            room_id,
            receipt_type,
            user_id,
            event_ids,
            data,
        )

    def insert_graph_receipt_txn(
        self,
        txn: LoggingTransaction,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_ids: List[str],
        data: JsonDict,
    ) -> None:
        assert self._can_write_to_receipts

        txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
        txn.call_after(
            self._invalidate_get_users_with_receipts_in_room,
            room_id,
            receipt_type,
            user_id,
        )
        txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
        # FIXME: This shouldn't invalidate the whole cache
        txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))

        self.db_pool.simple_delete_txn(
            txn,
            table="receipts_graph",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
        )
        self.db_pool.simple_insert_txn(
            txn,
            table="receipts_graph",
            values={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
                "event_ids": json_encoder.encode(event_ids),
                "data": json_encoder.encode(data),
            },
        )


class ReceiptsStore(ReceiptsWorkerStore):
    pass