# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Collection,
    Dict,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
)

from twisted.internet import defer

from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class ReceiptsWorkerStore(SQLBaseStore):
    """Storage methods for reading and writing read receipts.

    Receipts are stored both in "graph" form (the raw set of event IDs sent
    by the client) and "linearized" form (a single event ID per
    room/receipt-type/user, ordered by stream ID) — the linearized form is
    what gets sent to clients.
    """

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        self._instance_name = hs.get_instance_name()

        if isinstance(database.engine, PostgresEngine):
            # On Postgres multiple writers are supported; this instance may
            # only write receipts if it is one of the configured writers.
            self._can_write_to_receipts = (
                self._instance_name in hs.config.worker.writers.receipts
            )

            self._receipts_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="receipts",
                instance_name=self._instance_name,
                tables=[("receipts_linearized", "instance_name", "stream_id")],
                sequence_name="receipts_sequence",
                writers=hs.config.worker.writers.receipts,
            )
        else:
            self._can_write_to_receipts = True

            # We shouldn't be running in worker mode with SQLite, but its useful
            # to support it for unit tests.
            #
            # If this process is the writer than we need to use
            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
            # updated over replication. (Multiple writers are not supported for
            # SQLite).
            if hs.get_instance_name() in hs.config.worker.writers.receipts:
                self._receipts_id_gen = StreamIdGenerator(
                    db_conn, "receipts_linearized", "stream_id"
                )
            else:
                self._receipts_id_gen = SlavedIdTracker(
                    db_conn, "receipts_linearized", "stream_id"
                )

        super().__init__(database, db_conn, hs)

        # Tracks which rooms have had receipts added recently, so readers can
        # cheaply skip DB queries for unchanged rooms.
        self._receipts_stream_cache = StreamChangeCache(
            "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
        )

    def get_max_receipt_stream_id(self) -> int:
        """Get the current max stream ID for receipts stream"""
        return self._receipts_id_gen.get_current_token()

    @cached()
    async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]:
        """Get the set of user IDs with a read receipt in the given room."""
        receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ)
        return {r["user_id"] for r in receipts}

    @cached(num_args=2)
    async def get_receipts_for_room(
        self, room_id: str, receipt_type: str
    ) -> List[Dict[str, Any]]:
        """Get all receipts of the given type in a room.

        Returns:
            A list of dicts, each with "user_id" and "event_id" keys.
        """
        return await self.db_pool.simple_select_list(
            table="receipts_linearized",
            keyvalues={"room_id": room_id, "receipt_type": receipt_type},
            retcols=("user_id", "event_id"),
            desc="get_receipts_for_room",
        )

    @cached(num_args=3)
    async def get_last_receipt_event_id_for_user(
        self, user_id: str, room_id: str, receipt_type: str
    ) -> Optional[str]:
        """Get the event ID of the user's latest receipt in the room, if any."""
        return await self.db_pool.simple_select_one_onecol(
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
            retcol="event_id",
            desc="get_own_receipt_for_user",
            allow_none=True,
        )

    @cached(num_args=2)
    async def get_receipts_for_user(
        self, user_id: str, receipt_type: str
    ) -> Dict[str, str]:
        """Get the user's receipts across all rooms.

        Returns:
            A map of room ID to the event ID of the user's receipt there.
        """
        rows = await self.db_pool.simple_select_list(
            table="receipts_linearized",
            keyvalues={"user_id": user_id, "receipt_type": receipt_type},
            retcols=("room_id", "event_id"),
            desc="get_receipts_for_user",
        )

        return {row["room_id"]: row["event_id"] for row in rows}

    async def get_receipts_for_user_with_orderings(
        self, user_id: str, receipt_type: str
    ) -> JsonDict:
        """Get the user's receipts, annotated with each event's orderings.

        Returns:
            A map of room ID to a dict with "event_id",
            "topological_ordering" and "stream_ordering" keys.
        """

        def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
            sql = (
                "SELECT rl.room_id, rl.event_id,"
                " e.topological_ordering, e.stream_ordering"
                " FROM receipts_linearized AS rl"
                " INNER JOIN events AS e USING (room_id, event_id)"
                " WHERE rl.room_id = e.room_id"
                " AND rl.event_id = e.event_id"
                " AND user_id = ?"
            )
            txn.execute(sql, (user_id,))
            return txn.fetchall()

        rows = await self.db_pool.runInteraction(
            "get_receipts_for_user_with_orderings", f
        )
        return {
            row[0]: {
                "event_id": row[1],
                "topological_ordering": row[2],
                "stream_ordering": row[3],
            }
            for row in rows
        }

    async def get_linearized_receipts_for_rooms(
        self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
    ) -> List[dict]:
        """Get receipts for multiple rooms for sending to clients.

        Args:
            room_ids: The room IDs to fetch receipts of.
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A list of receipts.
        """
        room_ids = set(room_ids)

        if from_key is not None:
            # Only ask the database about rooms where there have been new
            # receipts added since `from_key`
            room_ids = self._receipts_stream_cache.get_entities_changed(
                room_ids, from_key
            )

        results = await self._get_linearized_receipts_for_rooms(
            room_ids, to_key, from_key=from_key
        )

        return [ev for res in results.values() for ev in res]

    async def get_linearized_receipts_for_room(
        self, room_id: str, to_key: int, from_key: Optional[int] = None
    ) -> List[dict]:
        """Get receipts for a single room for sending to clients.

        Args:
            room_id: The room id.
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A list of receipts.
        """
        if from_key is not None:
            # Check the cache first to see if any new receipts have been added
            # since`from_key`. If not we can no-op.
            if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
                return []

        return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)

    @cached(num_args=3, tree=True)
    async def _get_linearized_receipts_for_room(
        self, room_id: str, to_key: int, from_key: Optional[int] = None
    ) -> List[JsonDict]:
        """See get_linearized_receipts_for_room"""

        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
            if from_key:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id > ? AND stream_id <= ?"
                )

                txn.execute(sql, (room_id, from_key, to_key))
            else:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id <= ?"
                )

                txn.execute(sql, (room_id, to_key))

            rows = self.db_pool.cursor_to_dict(txn)

            return rows

        rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)

        if not rows:
            return []

        # Batch the receipts into a single m.receipt event for the room, of
        # the form {event_id: {receipt_type: {user_id: <receipt data>}}}.
        content = {}
        for row in rows:
            content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
                row["user_id"]
            ] = db_to_json(row["data"])

        return [{"type": "m.receipt", "room_id": room_id, "content": content}]

    @cachedList(
        cached_method_name="_get_linearized_receipts_for_room",
        list_name="room_ids",
        num_args=3,
    )
    async def _get_linearized_receipts_for_rooms(
        self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
    ) -> Dict[str, List[JsonDict]]:
        """Batched version of `_get_linearized_receipts_for_room`.

        Returns:
            A map of room ID to a (possibly empty) list containing at most one
            m.receipt event for that room.
        """
        if not room_ids:
            return {}

        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
            if from_key:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id > ? AND stream_id <= ? AND
                """
                clause, args = make_in_list_sql_clause(
                    self.database_engine, "room_id", room_ids
                )

                txn.execute(sql + clause, [from_key, to_key] + list(args))
            else:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id <= ? AND
                """

                clause, args = make_in_list_sql_clause(
                    self.database_engine, "room_id", room_ids
                )

                txn.execute(sql + clause, [to_key] + list(args))

            return self.db_pool.cursor_to_dict(txn)

        txn_results = await self.db_pool.runInteraction(
            "_get_linearized_receipts_for_rooms", f
        )

        results = {}
        for row in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = results.setdefault(
                row["room_id"],
                {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
            )

            # The content is of the form:
            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
            event_entry = room_event["content"].setdefault(row["event_id"], {})
            receipt_type = event_entry.setdefault(row["receipt_type"], {})

            receipt_type[row["user_id"]] = db_to_json(row["data"])

        # Ensure every requested room appears in the result, even if it had
        # no receipts in the window.
        results = {
            room_id: [results[room_id]] if room_id in results else []
            for room_id in room_ids
        }
        return results

    @cached(num_args=2)
    async def get_linearized_receipts_for_all_rooms(
        self, to_key: int, from_key: Optional[int] = None
    ) -> Dict[str, JsonDict]:
        """Get receipts for all rooms between two stream_ids, up
        to a limit of the latest 100 read receipts.

        Args:
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A dictionary of roomids to a list of receipts.
        """

        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
            if from_key:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id > ? AND stream_id <= ?
                    ORDER BY stream_id DESC
                    LIMIT 100
                """
                txn.execute(sql, [from_key, to_key])
            else:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id <= ?
                    ORDER BY stream_id DESC
                    LIMIT 100
                """

                txn.execute(sql, [to_key])

            return self.db_pool.cursor_to_dict(txn)

        txn_results = await self.db_pool.runInteraction(
            "get_linearized_receipts_for_all_rooms", f
        )

        results = {}
        for row in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = results.setdefault(
                row["room_id"],
                {"type": "m.receipt", "room_id": row["room_id"], "content": {}},
            )

            # The content is of the form:
            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
            event_entry = room_event["content"].setdefault(row["event_id"], {})
            receipt_type = event_entry.setdefault(row["receipt_type"], {})

            receipt_type[row["user_id"]] = db_to_json(row["data"])

        return results

    async def get_users_sent_receipts_between(
        self, last_id: int, current_id: int
    ) -> List[str]:
        """Get all users who sent receipts between `last_id` exclusive and
        `current_id` inclusive.

        Returns:
            The list of users.
        """

        if last_id == current_id:
            # FIX: this is an async function, so returning `defer.succeed([])`
            # would hand the awaiting caller a Deferred rather than a list.
            return []

        def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
            sql = """
                SELECT DISTINCT user_id FROM receipts_linearized
                WHERE ? < stream_id AND stream_id <= ?
            """
            txn.execute(sql, (last_id, current_id))

            return [r[0] for r in txn]

        return await self.db_pool.runInteraction(
            "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
        )

    async def get_all_updated_receipts(
        self, instance_name: str, last_id: int, current_id: int, limit: int
    ) -> Tuple[List[Tuple[int, list]], int, bool]:
        """Get updates for receipts replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exists
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_updated_receipts_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[Tuple[int, list]], int, bool]:
            sql = """
                SELECT stream_id, room_id, receipt_type, user_id, event_id, data
                FROM receipts_linearized
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
                LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, limit))

            updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]

            limited = False
            upper_bound = current_id

            if len(updates) == limit:
                limited = True
                upper_bound = updates[-1][0]

            return updates, upper_bound, limited

        return await self.db_pool.runInteraction(
            "get_all_updated_receipts", get_all_updated_receipts_txn
        )

    def _invalidate_get_users_with_receipts_in_room(
        self, room_id: str, receipt_type: str, user_id: str
    ) -> None:
        if receipt_type != ReceiptTypes.READ:
            return

        res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
            room_id, None, update_metrics=False
        )

        if res and user_id in res:
            # We'd only be adding to the set, so no point invalidating if the
            # user is already there
            return

        self.get_users_with_read_receipts_in_room.invalidate((room_id,))

    def invalidate_caches_for_receipt(
        self, room_id: str, receipt_type: str, user_id: str
    ) -> None:
        """Invalidate all receipt-related caches affected by a new receipt."""
        self.get_receipts_for_user.invalidate((user_id, receipt_type))
        self._get_linearized_receipts_for_room.invalidate((room_id,))
        self.get_last_receipt_event_id_for_user.invalidate(
            (user_id, room_id, receipt_type)
        )
        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
        self.get_receipts_for_room.invalidate((room_id, receipt_type))

    def process_replication_rows(
        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
    ) -> None:
        """Advance the receipts ID tracker and invalidate caches for rows
        arriving over replication."""
        if stream_name == ReceiptsStream.NAME:
            self._receipts_id_gen.advance(instance_name, token)
            for row in rows:
                self.invalidate_caches_for_receipt(
                    row.room_id, row.receipt_type, row.user_id
                )
                self._receipts_stream_cache.entity_has_changed(row.room_id, token)

        return super().process_replication_rows(stream_name, instance_name, token, rows)

    def insert_linearized_receipt_txn(
        self,
        txn: LoggingTransaction,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_id: str,
        data: JsonDict,
        stream_id: int,
    ) -> Optional[int]:
        """Inserts a read-receipt into the database if it's newer than the current RR

        Returns:
            None if the RR is older than the current RR
            otherwise, the rx timestamp of the event that the RR corresponds to
            (or 0 if the event is unknown)
        """
        assert self._can_write_to_receipts

        res = self.db_pool.simple_select_one_txn(
            txn,
            table="events",
            retcols=["stream_ordering", "received_ts"],
            keyvalues={"event_id": event_id},
            allow_none=True,
        )

        stream_ordering = int(res["stream_ordering"]) if res else None
        rx_ts = res["received_ts"] if res else 0

        # We don't want to clobber receipts for more recent events, so we
        # have to compare orderings of existing receipts
        if stream_ordering is not None:
            sql = (
                "SELECT stream_ordering, event_id FROM events"
                " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
                " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
            )
            txn.execute(sql, (room_id, receipt_type, user_id))

            for so, eid in txn:
                if int(so) >= stream_ordering:
                    logger.debug(
                        "Ignoring new receipt for %s in favour of existing "
                        "one for later event %s",
                        event_id,
                        eid,
                    )
                    return None

        txn.call_after(
            self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
        )

        txn.call_after(
            self._receipts_stream_cache.entity_has_changed, room_id, stream_id
        )

        self.db_pool.simple_upsert_txn(
            txn,
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
            values={
                "stream_id": stream_id,
                "event_id": event_id,
                "data": json_encoder.encode(data),
            },
            # receipts_linearized has a unique constraint on
            # (user_id, room_id, receipt_type), so no need to lock
            lock=False,
        )

        if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
            self._remove_old_push_actions_before_txn(
                txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
            )

        return rx_ts

    async def insert_receipt(
        self,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_ids: List[str],
        data: dict,
    ) -> Optional[Tuple[int, int]]:
        """Insert a receipt, either from local client or remote server.

        Automatically does conversion between linearized and graph
        representations.
        """
        assert self._can_write_to_receipts

        if not event_ids:
            return None

        if len(event_ids) == 1:
            linearized_event_id = event_ids[0]
        else:
            # we need to points in graph -> linearized form.
            # TODO: Make this better.
            def graph_to_linear(txn: LoggingTransaction) -> str:
                clause, args = make_in_list_sql_clause(
                    self.database_engine, "event_id", event_ids
                )

                # FIX: the previous query omitted the `FROM events` clauses,
                # which is invalid SQL on both SQLite and Postgres (column
                # references require a FROM clause).
                sql = """
                    SELECT event_id FROM events WHERE room_id = ? AND stream_ordering IN (
                        SELECT max(stream_ordering) FROM events WHERE %s
                    )
                """ % (
                    clause,
                )

                txn.execute(sql, [room_id] + list(args))
                rows = txn.fetchall()
                if rows:
                    return rows[0][0]
                else:
                    raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))

            linearized_event_id = await self.db_pool.runInteraction(
                "insert_receipt_conv", graph_to_linear
            )

        async with self._receipts_id_gen.get_next() as stream_id:
            event_ts = await self.db_pool.runInteraction(
                "insert_linearized_receipt",
                self.insert_linearized_receipt_txn,
                room_id,
                receipt_type,
                user_id,
                linearized_event_id,
                data,
                stream_id=stream_id,
            )

        if event_ts is None:
            return None

        now = self._clock.time_msec()
        logger.debug(
            "RR for event %s in %s (%i ms old)",
            linearized_event_id,
            room_id,
            now - event_ts,
        )

        await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)

        max_persisted_id = self._receipts_id_gen.get_current_token()

        return stream_id, max_persisted_id

    async def insert_graph_receipt(
        self,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_ids: List[str],
        data: JsonDict,
    ) -> None:
        """Persist the raw (graph-form) set of event IDs for a receipt."""
        assert self._can_write_to_receipts

        await self.db_pool.runInteraction(
            "insert_graph_receipt",
            self.insert_graph_receipt_txn,
            room_id,
            receipt_type,
            user_id,
            event_ids,
            data,
        )

    def insert_graph_receipt_txn(
        self,
        txn: LoggingTransaction,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_ids: List[str],
        data: JsonDict,
    ) -> None:
        """Transaction half of `insert_graph_receipt`: replaces any existing
        graph receipt row for (room, type, user) and invalidates caches."""
        assert self._can_write_to_receipts

        txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
        txn.call_after(
            self._invalidate_get_users_with_receipts_in_room,
            room_id,
            receipt_type,
            user_id,
        )
        txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
        # FIXME: This shouldn't invalidate the whole cache
        txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))

        self.db_pool.simple_delete_txn(
            txn,
            table="receipts_graph",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
        )
        self.db_pool.simple_insert_txn(
            txn,
            table="receipts_graph",
            values={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
                "event_ids": json_encoder.encode(event_ids),
                "data": json_encoder.encode(data),
            },
        )


class ReceiptsStore(ReceiptsWorkerStore):
    pass