1# Copyright 2018 New Vector Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15import logging 16import threading 17from typing import ( 18 TYPE_CHECKING, 19 Any, 20 Collection, 21 Container, 22 Dict, 23 Iterable, 24 List, 25 NoReturn, 26 Optional, 27 Set, 28 Tuple, 29 cast, 30 overload, 31) 32 33import attr 34from constantly import NamedConstant, Names 35from prometheus_client import Gauge 36from typing_extensions import Literal 37 38from twisted.internet import defer 39 40from synapse.api.constants import EventTypes 41from synapse.api.errors import NotFoundError, SynapseError 42from synapse.api.room_versions import ( 43 KNOWN_ROOM_VERSIONS, 44 EventFormatVersions, 45 RoomVersion, 46 RoomVersions, 47) 48from synapse.events import EventBase, make_event_from_dict 49from synapse.events.snapshot import EventContext 50from synapse.events.utils import prune_event 51from synapse.logging.context import ( 52 PreserveLoggingContext, 53 current_context, 54 make_deferred_yieldable, 55) 56from synapse.metrics.background_process_metrics import ( 57 run_as_background_process, 58 wrap_as_background_process, 59) 60from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker 61from synapse.replication.tcp.streams import BackfillStream 62from synapse.replication.tcp.streams.events import EventsStream 63from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause 64from synapse.storage.database import ( 65 DatabasePool, 66 LoggingDatabaseConnection, 67 LoggingTransaction, 68) 69from synapse.storage.engines import PostgresEngine 70from synapse.storage.types import Cursor 71from synapse.storage.util.id_generators import ( 72 AbstractStreamIdTracker, 73 MultiWriterIdGenerator, 74 StreamIdGenerator, 75) 76from synapse.storage.util.sequence import build_sequence_generator 77from synapse.types import JsonDict, get_domain_from_id 78from synapse.util import unwrapFirstError 79from synapse.util.async_helpers import ObservableDeferred 80from synapse.util.caches.descriptors import cached, cachedList 81from synapse.util.caches.lrucache import LruCache 82from synapse.util.iterutils import batch_iter 83from synapse.util.metrics import Measure 84 85if TYPE_CHECKING: 86 from synapse.server import HomeServer 87 88logger = logging.getLogger(__name__) 89 90 91# These values are used in the `enqueue_event` and `_fetch_loop` methods to 92# control how we batch/bulk fetch events from the database. 93# The values are plucked out of thing air to make initial sync run faster 94# on jki.re 95# TODO: Make these configurable. 96EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events 97EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events 98EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events 99 100 101event_fetch_ongoing_gauge = Gauge( 102 "synapse_event_fetch_ongoing", 103 "The number of event fetchers that are running", 104) 105 106 107@attr.s(slots=True, auto_attribs=True) 108class EventCacheEntry: 109 event: EventBase 110 redacted_event: Optional[EventBase] 111 112 113@attr.s(slots=True, frozen=True, auto_attribs=True) 114class _EventRow: 115 """ 116 An event, as pulled from the database. 117 118 Properties: 119 event_id: The event ID of the event. 120 121 stream_ordering: stream ordering for this event 122 123 json: json-encoded event structure 124 125 internal_metadata: json-encoded internal metadata dict 126 127 format_version: The format of the event. Hopefully one of EventFormatVersions. 128 'None' means the event predates EventFormatVersions (so the event is format V1). 129 130 room_version_id: The version of the room which contains the event. Hopefully 131 one of RoomVersions. 132 133 Due to historical reasons, there may be a few events in the database which 134 do not have an associated room; in this case None will be returned here. 135 136 rejected_reason: if the event was rejected, the reason why. 137 138 redactions: a list of event-ids which (claim to) redact this event. 139 140 outlier: True if this event is an outlier. 141 """ 142 143 event_id: str 144 stream_ordering: int 145 json: str 146 internal_metadata: str 147 format_version: Optional[int] 148 room_version_id: Optional[str] 149 rejected_reason: Optional[str] 150 redactions: List[str] 151 outlier: bool 152 153 154class EventRedactBehaviour(Names): 155 """ 156 What to do when retrieving a redacted event from the database. 157 """ 158 159 AS_IS = NamedConstant() 160 REDACT = NamedConstant() 161 BLOCK = NamedConstant() 162 163 164class EventsWorkerStore(SQLBaseStore): 165 # Whether to use dedicated DB threads for event fetching. This is only used 166 # if there are multiple DB threads available. When used will lock the DB 167 # thread for periods of time (so unit tests want to disable this when they 168 # run DB transactions on the main thread). See EVENT_QUEUE_* for more 169 # options controlling this. 170 USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True 171 172 def __init__( 173 self, 174 database: DatabasePool, 175 db_conn: LoggingDatabaseConnection, 176 hs: "HomeServer", 177 ): 178 super().__init__(database, db_conn, hs) 179 180 self._stream_id_gen: AbstractStreamIdTracker 181 self._backfill_id_gen: AbstractStreamIdTracker 182 if isinstance(database.engine, PostgresEngine): 183 # If we're using Postgres than we can use `MultiWriterIdGenerator` 184 # regardless of whether this process writes to the streams or not. 185 self._stream_id_gen = MultiWriterIdGenerator( 186 db_conn=db_conn, 187 db=database, 188 stream_name="events", 189 instance_name=hs.get_instance_name(), 190 tables=[("events", "instance_name", "stream_ordering")], 191 sequence_name="events_stream_seq", 192 writers=hs.config.worker.writers.events, 193 ) 194 self._backfill_id_gen = MultiWriterIdGenerator( 195 db_conn=db_conn, 196 db=database, 197 stream_name="backfill", 198 instance_name=hs.get_instance_name(), 199 tables=[("events", "instance_name", "stream_ordering")], 200 sequence_name="events_backfill_stream_seq", 201 positive=False, 202 writers=hs.config.worker.writers.events, 203 ) 204 else: 205 # We shouldn't be running in worker mode with SQLite, but its useful 206 # to support it for unit tests. 207 # 208 # If this process is the writer than we need to use 209 # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets 210 # updated over replication. (Multiple writers are not supported for 211 # SQLite). 212 if hs.get_instance_name() in hs.config.worker.writers.events: 213 self._stream_id_gen = StreamIdGenerator( 214 db_conn, 215 "events", 216 "stream_ordering", 217 ) 218 self._backfill_id_gen = StreamIdGenerator( 219 db_conn, 220 "events", 221 "stream_ordering", 222 step=-1, 223 extra_tables=[("ex_outlier_stream", "event_stream_ordering")], 224 ) 225 else: 226 self._stream_id_gen = SlavedIdTracker( 227 db_conn, "events", "stream_ordering" 228 ) 229 self._backfill_id_gen = SlavedIdTracker( 230 db_conn, "events", "stream_ordering", step=-1 231 ) 232 233 if hs.config.worker.run_background_tasks: 234 # We periodically clean out old transaction ID mappings 235 self._clock.looping_call( 236 self._cleanup_old_transaction_ids, 237 5 * 60 * 1000, 238 ) 239 240 self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache( 241 cache_name="*getEvent*", 242 max_size=hs.config.caches.event_cache_size, 243 ) 244 245 # Map from event ID to a deferred that will result in a map from event 246 # ID to cache entry. Note that the returned dict may not have the 247 # requested event in it if the event isn't in the DB. 248 self._current_event_fetches: Dict[ 249 str, ObservableDeferred[Dict[str, EventCacheEntry]] 250 ] = {} 251 252 self._event_fetch_lock = threading.Condition() 253 self._event_fetch_list: List[ 254 Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"] 255 ] = [] 256 self._event_fetch_ongoing = 0 257 event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) 258 259 # We define this sequence here so that it can be referenced from both 260 # the DataStore and PersistEventStore. 261 def get_chain_id_txn(txn: Cursor) -> int: 262 txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains") 263 return cast(Tuple[int], txn.fetchone())[0] 264 265 self.event_chain_id_gen = build_sequence_generator( 266 db_conn, 267 database.engine, 268 get_chain_id_txn, 269 "event_auth_chain_id", 270 table="event_auth_chains", 271 id_column="chain_id", 272 ) 273 274 def process_replication_rows( 275 self, 276 stream_name: str, 277 instance_name: str, 278 token: int, 279 rows: Iterable[Any], 280 ) -> None: 281 if stream_name == EventsStream.NAME: 282 self._stream_id_gen.advance(instance_name, token) 283 elif stream_name == BackfillStream.NAME: 284 self._backfill_id_gen.advance(instance_name, -token) 285 286 super().process_replication_rows(stream_name, instance_name, token, rows) 287 288 async def get_received_ts(self, event_id: str) -> Optional[int]: 289 """Get received_ts (when it was persisted) for the event. 290 291 Raises an exception for unknown events. 292 293 Args: 294 event_id: The event ID to query. 295 296 Returns: 297 Timestamp in milliseconds, or None for events that were persisted 298 before received_ts was implemented. 299 """ 300 return await self.db_pool.simple_select_one_onecol( 301 table="events", 302 keyvalues={"event_id": event_id}, 303 retcol="received_ts", 304 desc="get_received_ts", 305 ) 306 307 # Inform mypy that if allow_none is False (the default) then get_event 308 # always returns an EventBase. 309 @overload 310 async def get_event( 311 self, 312 event_id: str, 313 redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, 314 get_prev_content: bool = ..., 315 allow_rejected: bool = ..., 316 allow_none: Literal[False] = ..., 317 check_room_id: Optional[str] = ..., 318 ) -> EventBase: 319 ... 320 321 @overload 322 async def get_event( 323 self, 324 event_id: str, 325 redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, 326 get_prev_content: bool = ..., 327 allow_rejected: bool = ..., 328 allow_none: Literal[True] = ..., 329 check_room_id: Optional[str] = ..., 330 ) -> Optional[EventBase]: 331 ... 332 333 async def get_event( 334 self, 335 event_id: str, 336 redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, 337 get_prev_content: bool = False, 338 allow_rejected: bool = False, 339 allow_none: bool = False, 340 check_room_id: Optional[str] = None, 341 ) -> Optional[EventBase]: 342 """Get an event from the database by event_id. 343 344 Args: 345 event_id: The event_id of the event to fetch 346 347 redact_behaviour: Determine what to do with a redacted event. Possible values: 348 * AS_IS - Return the full event body with no redacted content 349 * REDACT - Return the event but with a redacted body 350 * DISALLOW - Do not return redacted events (behave as per allow_none 351 if the event is redacted) 352 353 get_prev_content: If True and event is a state event, 354 include the previous states content in the unsigned field. 355 356 allow_rejected: If True, return rejected events. Otherwise, 357 behave as per allow_none. 358 359 allow_none: If True, return None if no event found, if 360 False throw a NotFoundError 361 362 check_room_id: if not None, check the room of the found event. 363 If there is a mismatch, behave as per allow_none. 364 365 Returns: 366 The event, or None if the event was not found. 367 """ 368 if not isinstance(event_id, str): 369 raise TypeError("Invalid event event_id %r" % (event_id,)) 370 371 events = await self.get_events_as_list( 372 [event_id], 373 redact_behaviour=redact_behaviour, 374 get_prev_content=get_prev_content, 375 allow_rejected=allow_rejected, 376 ) 377 378 event = events[0] if events else None 379 380 if event is not None and check_room_id is not None: 381 if event.room_id != check_room_id: 382 event = None 383 384 if event is None and not allow_none: 385 raise NotFoundError("Could not find event %s" % (event_id,)) 386 387 return event 388 389 async def get_events( 390 self, 391 event_ids: Collection[str], 392 redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, 393 get_prev_content: bool = False, 394 allow_rejected: bool = False, 395 ) -> Dict[str, EventBase]: 396 """Get events from the database 397 398 Args: 399 event_ids: The event_ids of the events to fetch 400 401 redact_behaviour: Determine what to do with a redacted event. Possible 402 values: 403 * AS_IS - Return the full event body with no redacted content 404 * REDACT - Return the event but with a redacted body 405 * DISALLOW - Do not return redacted events (omit them from the response) 406 407 get_prev_content: If True and event is a state event, 408 include the previous states content in the unsigned field. 409 410 allow_rejected: If True, return rejected events. Otherwise, 411 omits rejeted events from the response. 412 413 Returns: 414 A mapping from event_id to event. 415 """ 416 events = await self.get_events_as_list( 417 event_ids, 418 redact_behaviour=redact_behaviour, 419 get_prev_content=get_prev_content, 420 allow_rejected=allow_rejected, 421 ) 422 423 return {e.event_id: e for e in events} 424 425 async def get_events_as_list( 426 self, 427 event_ids: Collection[str], 428 redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, 429 get_prev_content: bool = False, 430 allow_rejected: bool = False, 431 ) -> List[EventBase]: 432 """Get events from the database and return in a list in the same order 433 as given by `event_ids` arg. 434 435 Unknown events will be omitted from the response. 436 437 Args: 438 event_ids: The event_ids of the events to fetch 439 440 redact_behaviour: Determine what to do with a redacted event. Possible values: 441 * AS_IS - Return the full event body with no redacted content 442 * REDACT - Return the event but with a redacted body 443 * DISALLOW - Do not return redacted events (omit them from the response) 444 445 get_prev_content: If True and event is a state event, 446 include the previous states content in the unsigned field. 447 448 allow_rejected: If True, return rejected events. Otherwise, 449 omits rejected events from the response. 450 451 Returns: 452 List of events fetched from the database. The events are in the same 453 order as `event_ids` arg. 454 455 Note that the returned list may be smaller than the list of event 456 IDs if not all events could be fetched. 457 """ 458 459 if not event_ids: 460 return [] 461 462 # there may be duplicates so we cast the list to a set 463 event_entry_map = await self._get_events_from_cache_or_db( 464 set(event_ids), allow_rejected=allow_rejected 465 ) 466 467 events = [] 468 for event_id in event_ids: 469 entry = event_entry_map.get(event_id, None) 470 if not entry: 471 continue 472 473 if not allow_rejected: 474 assert not entry.event.rejected_reason, ( 475 "rejected event returned from _get_events_from_cache_or_db despite " 476 "allow_rejected=False" 477 ) 478 479 # We may not have had the original event when we received a redaction, so 480 # we have to recheck auth now. 481 482 if not allow_rejected and entry.event.type == EventTypes.Redaction: 483 if entry.event.redacts is None: 484 # A redacted redaction doesn't have a `redacts` key, in 485 # which case lets just withhold the event. 486 # 487 # Note: Most of the time if the redactions has been 488 # redacted we still have the un-redacted event in the DB 489 # and so we'll still see the `redacts` key. However, this 490 # isn't always true e.g. if we have censored the event. 491 logger.debug( 492 "Withholding redaction event %s as we don't have redacts key", 493 event_id, 494 ) 495 continue 496 497 redacted_event_id = entry.event.redacts 498 event_map = await self._get_events_from_cache_or_db([redacted_event_id]) 499 original_event_entry = event_map.get(redacted_event_id) 500 if not original_event_entry: 501 # we don't have the redacted event (or it was rejected). 502 # 503 # We assume that the redaction isn't authorized for now; if the 504 # redacted event later turns up, the redaction will be re-checked, 505 # and if it is found valid, the original will get redacted before it 506 # is served to the client. 507 logger.debug( 508 "Withholding redaction event %s since we don't (yet) have the " 509 "original %s", 510 event_id, 511 redacted_event_id, 512 ) 513 continue 514 515 original_event = original_event_entry.event 516 if original_event.type == EventTypes.Create: 517 # we never serve redactions of Creates to clients. 518 logger.info( 519 "Withholding redaction %s of create event %s", 520 event_id, 521 redacted_event_id, 522 ) 523 continue 524 525 if original_event.room_id != entry.event.room_id: 526 logger.info( 527 "Withholding redaction %s of event %s from a different room", 528 event_id, 529 redacted_event_id, 530 ) 531 continue 532 533 if entry.event.internal_metadata.need_to_check_redaction(): 534 original_domain = get_domain_from_id(original_event.sender) 535 redaction_domain = get_domain_from_id(entry.event.sender) 536 if original_domain != redaction_domain: 537 # the senders don't match, so this is forbidden 538 logger.info( 539 "Withholding redaction %s whose sender domain %s doesn't " 540 "match that of redacted event %s %s", 541 event_id, 542 redaction_domain, 543 redacted_event_id, 544 original_domain, 545 ) 546 continue 547 548 # Update the cache to save doing the checks again. 549 entry.event.internal_metadata.recheck_redaction = False 550 551 event = entry.event 552 553 if entry.redacted_event: 554 if redact_behaviour == EventRedactBehaviour.BLOCK: 555 # Skip this event 556 continue 557 elif redact_behaviour == EventRedactBehaviour.REDACT: 558 event = entry.redacted_event 559 560 events.append(event) 561 562 if get_prev_content: 563 if "replaces_state" in event.unsigned: 564 prev = await self.get_event( 565 event.unsigned["replaces_state"], 566 get_prev_content=False, 567 allow_none=True, 568 ) 569 if prev: 570 event.unsigned = dict(event.unsigned) 571 event.unsigned["prev_content"] = prev.content 572 event.unsigned["prev_sender"] = prev.sender 573 574 return events 575 576 async def _get_events_from_cache_or_db( 577 self, event_ids: Iterable[str], allow_rejected: bool = False 578 ) -> Dict[str, EventCacheEntry]: 579 """Fetch a bunch of events from the cache or the database. 580 581 If events are pulled from the database, they will be cached for future lookups. 582 583 Unknown events are omitted from the response. 584 585 Args: 586 587 event_ids: The event_ids of the events to fetch 588 589 allow_rejected: Whether to include rejected events. If False, 590 rejected events are omitted from the response. 591 592 Returns: 593 map from event id to result 594 """ 595 event_entry_map = self._get_events_from_cache( 596 event_ids, 597 ) 598 599 missing_events_ids = {e for e in event_ids if e not in event_entry_map} 600 601 # We now look up if we're already fetching some of the events in the DB, 602 # if so we wait for those lookups to finish instead of pulling the same 603 # events out of the DB multiple times. 604 # 605 # Note: we might get the same `ObservableDeferred` back for multiple 606 # events we're already fetching, so we deduplicate the deferreds to 607 # avoid extraneous work (if we don't do this we can end up in a n^2 mode 608 # when we wait on the same Deferred N times, then try and merge the 609 # same dict into itself N times). 610 already_fetching_ids: Set[str] = set() 611 already_fetching_deferreds: Set[ 612 ObservableDeferred[Dict[str, EventCacheEntry]] 613 ] = set() 614 615 for event_id in missing_events_ids: 616 deferred = self._current_event_fetches.get(event_id) 617 if deferred is not None: 618 # We're already pulling the event out of the DB. Add the deferred 619 # to the collection of deferreds to wait on. 620 already_fetching_ids.add(event_id) 621 already_fetching_deferreds.add(deferred) 622 623 missing_events_ids.difference_update(already_fetching_ids) 624 625 if missing_events_ids: 626 log_ctx = current_context() 627 log_ctx.record_event_fetch(len(missing_events_ids)) 628 629 # Add entries to `self._current_event_fetches` for each event we're 630 # going to pull from the DB. We use a single deferred that resolves 631 # to all the events we pulled from the DB (this will result in this 632 # function returning more events than requested, but that can happen 633 # already due to `_get_events_from_db`). 634 fetching_deferred: ObservableDeferred[ 635 Dict[str, EventCacheEntry] 636 ] = ObservableDeferred(defer.Deferred(), consumeErrors=True) 637 for event_id in missing_events_ids: 638 self._current_event_fetches[event_id] = fetching_deferred 639 640 # Note that _get_events_from_db is also responsible for turning db rows 641 # into FrozenEvents (via _get_event_from_row), which involves seeing if 642 # the events have been redacted, and if so pulling the redaction event out 643 # of the database to check it. 644 # 645 try: 646 missing_events = await self._get_events_from_db( 647 missing_events_ids, 648 ) 649 650 event_entry_map.update(missing_events) 651 except Exception as e: 652 with PreserveLoggingContext(): 653 fetching_deferred.errback(e) 654 raise e 655 finally: 656 # Ensure that we mark these events as no longer being fetched. 657 for event_id in missing_events_ids: 658 self._current_event_fetches.pop(event_id, None) 659 660 with PreserveLoggingContext(): 661 fetching_deferred.callback(missing_events) 662 663 if already_fetching_deferreds: 664 # Wait for the other event requests to finish and add their results 665 # to ours. 666 results = await make_deferred_yieldable( 667 defer.gatherResults( 668 (d.observe() for d in already_fetching_deferreds), 669 consumeErrors=True, 670 ) 671 ).addErrback(unwrapFirstError) 672 673 for result in results: 674 # We filter out events that we haven't asked for as we might get 675 # a *lot* of superfluous events back, and there is no point 676 # going through and inserting them all (which can take time). 677 event_entry_map.update( 678 (event_id, entry) 679 for event_id, entry in result.items() 680 if event_id in already_fetching_ids 681 ) 682 683 if not allow_rejected: 684 event_entry_map = { 685 event_id: entry 686 for event_id, entry in event_entry_map.items() 687 if not entry.event.rejected_reason 688 } 689 690 return event_entry_map 691 692 def _invalidate_get_event_cache(self, event_id: str) -> None: 693 self._get_event_cache.invalidate((event_id,)) 694 695 def _get_events_from_cache( 696 self, events: Iterable[str], update_metrics: bool = True 697 ) -> Dict[str, EventCacheEntry]: 698 """Fetch events from the caches. 699 700 May return rejected events. 701 702 Args: 703 events: list of event_ids to fetch 704 update_metrics: Whether to update the cache hit ratio metrics 705 """ 706 event_map = {} 707 708 for event_id in events: 709 ret = self._get_event_cache.get( 710 (event_id,), None, update_metrics=update_metrics 711 ) 712 if not ret: 713 continue 714 715 event_map[event_id] = ret 716 717 return event_map 718 719 async def get_stripped_room_state_from_event_context( 720 self, 721 context: EventContext, 722 state_types_to_include: Container[str], 723 membership_user_id: Optional[str] = None, 724 ) -> List[JsonDict]: 725 """ 726 Retrieve the stripped state from a room, given an event context to retrieve state 727 from as well as the state types to include. Optionally, include the membership 728 events from a specific user. 729 730 "Stripped" state means that only the `type`, `state_key`, `content` and `sender` keys 731 are included from each state event. 732 733 Args: 734 context: The event context to retrieve state of the room from. 735 state_types_to_include: The type of state events to include. 736 membership_user_id: An optional user ID to include the stripped membership state 737 events of. This is useful when generating the stripped state of a room for 738 invites. We want to send membership events of the inviter, so that the 739 invitee can display the inviter's profile information if the room lacks any. 740 741 Returns: 742 A list of dictionaries, each representing a stripped state event from the room. 743 """ 744 current_state_ids = await context.get_current_state_ids() 745 746 # We know this event is not an outlier, so this must be 747 # non-None. 748 assert current_state_ids is not None 749 750 # The state to include 751 state_to_include_ids = [ 752 e_id 753 for k, e_id in current_state_ids.items() 754 if k[0] in state_types_to_include 755 or (membership_user_id and k == (EventTypes.Member, membership_user_id)) 756 ] 757 758 state_to_include = await self.get_events(state_to_include_ids) 759 760 return [ 761 { 762 "type": e.type, 763 "state_key": e.state_key, 764 "content": e.content, 765 "sender": e.sender, 766 } 767 for e in state_to_include.values() 768 ] 769 770 def _maybe_start_fetch_thread(self) -> None: 771 """Starts an event fetch thread if we are not yet at the maximum number.""" 772 with self._event_fetch_lock: 773 if ( 774 self._event_fetch_list 775 and self._event_fetch_ongoing < EVENT_QUEUE_THREADS 776 ): 777 self._event_fetch_ongoing += 1 778 event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) 779 # `_event_fetch_ongoing` is decremented in `_fetch_thread`. 780 should_start = True 781 else: 782 should_start = False 783 784 if should_start: 785 run_as_background_process("fetch_events", self._fetch_thread) 786 787 async def _fetch_thread(self) -> None: 788 """Services requests for events from `_event_fetch_list`.""" 789 exc = None 790 try: 791 await self.db_pool.runWithConnection(self._fetch_loop) 792 except BaseException as e: 793 exc = e 794 raise 795 finally: 796 should_restart = False 797 event_fetches_to_fail = [] 798 with self._event_fetch_lock: 799 self._event_fetch_ongoing -= 1 800 event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) 801 802 # There may still be work remaining in `_event_fetch_list` if we 803 # failed, or it was added in between us deciding to exit and 804 # decrementing `_event_fetch_ongoing`. 805 if self._event_fetch_list: 806 if exc is None: 807 # We decided to exit, but then some more work was added 808 # before `_event_fetch_ongoing` was decremented. 809 # If a new event fetch thread was not started, we should 810 # restart ourselves since the remaining event fetch threads 811 # may take a while to get around to the new work. 812 # 813 # Unfortunately it is not possible to tell whether a new 814 # event fetch thread was started, so we restart 815 # unconditionally. If we are unlucky, we will end up with 816 # an idle fetch thread, but it will time out after 817 # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds 818 # in any case. 819 # 820 # Note that multiple fetch threads may run down this path at 821 # the same time. 822 should_restart = True 823 elif isinstance(exc, Exception): 824 if self._event_fetch_ongoing == 0: 825 # We were the last remaining fetcher and failed. 826 # Fail any outstanding fetches since no one else will 827 # handle them. 828 event_fetches_to_fail = self._event_fetch_list 829 self._event_fetch_list = [] 830 else: 831 # We weren't the last remaining fetcher, so another 832 # fetcher will pick up the work. This will either happen 833 # after their existing work, however long that takes, 834 # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if 835 # they are idle. 836 pass 837 else: 838 # The exception is a `SystemExit`, `KeyboardInterrupt` or 839 # `GeneratorExit`. Don't try to do anything clever here. 840 pass 841 842 if should_restart: 843 # We exited cleanly but noticed more work. 844 self._maybe_start_fetch_thread() 845 846 if event_fetches_to_fail: 847 # We were the last remaining fetcher and failed. 848 # Fail any outstanding fetches since no one else will handle them. 849 assert exc is not None 850 with PreserveLoggingContext(): 851 for _, deferred in event_fetches_to_fail: 852 deferred.errback(exc) 853 854 def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None: 855 """Takes a database connection and waits for requests for events from 856 the _event_fetch_list queue. 857 """ 858 i = 0 859 while True: 860 with self._event_fetch_lock: 861 event_list = self._event_fetch_list 862 self._event_fetch_list = [] 863 864 if not event_list: 865 # There are no requests waiting. If we haven't yet reached the 866 # maximum iteration limit, wait for some more requests to turn up. 867 # Otherwise, bail out. 868 single_threaded = self.database_engine.single_threaded 869 if ( 870 not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING 871 or single_threaded 872 or i > EVENT_QUEUE_ITERATIONS 873 ): 874 return 875 876 self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) 877 i += 1 878 continue 879 i = 0 880 881 self._fetch_event_list(conn, event_list) 882 883 def _fetch_event_list( 884 self, 885 conn: LoggingDatabaseConnection, 886 event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]], 887 ) -> None: 888 """Handle a load of requests from the _event_fetch_list queue 889 890 Args: 891 conn: database connection 892 893 event_list: 894 The fetch requests. Each entry consists of a list of event 895 ids to be fetched, and a deferred to be completed once the 896 events have been fetched. 897 898 The deferreds are callbacked with a dictionary mapping from event id 899 to event row. Note that it may well contain additional events that 900 were not part of this request. 901 """ 902 with Measure(self._clock, "_fetch_event_list"): 903 try: 904 events_to_fetch = { 905 event_id for events, _ in event_list for event_id in events 906 } 907 908 row_dict = self.db_pool.new_transaction( 909 conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch 910 ) 911 912 # We only want to resolve deferreds from the main thread 913 def fire() -> None: 914 for _, d in event_list: 915 d.callback(row_dict) 916 917 with PreserveLoggingContext(): 918 self.hs.get_reactor().callFromThread(fire) 919 except Exception as e: 920 logger.exception("do_fetch") 921 922 # We only want to resolve deferreds from the main thread 923 def fire_errback(exc: Exception) -> None: 924 for _, d in event_list: 925 d.errback(exc) 926 927 with PreserveLoggingContext(): 928 self.hs.get_reactor().callFromThread(fire_errback, e) 929 930 async def _get_events_from_db( 931 self, event_ids: Collection[str] 932 ) -> Dict[str, EventCacheEntry]: 933 """Fetch a bunch of events from the database. 934 935 May return rejected events. 936 937 Returned events will be added to the cache for future lookups. 938 939 Unknown events are omitted from the response. 940 941 Args: 942 event_ids: The event_ids of the events to fetch 943 944 Returns: 945 map from event id to result. May return extra events which 946 weren't asked for. 947 """ 948 fetched_event_ids: Set[str] = set() 949 fetched_events: Dict[str, _EventRow] = {} 950 events_to_fetch = event_ids 951 952 while events_to_fetch: 953 row_map = await self._enqueue_events(events_to_fetch) 954 955 # we need to recursively fetch any redactions of those events 956 redaction_ids: Set[str] = set() 957 for event_id in events_to_fetch: 958 row = row_map.get(event_id) 959 fetched_event_ids.add(event_id) 960 if row: 961 fetched_events[event_id] = row 962 redaction_ids.update(row.redactions) 963 964 events_to_fetch = redaction_ids.difference(fetched_event_ids) 965 if events_to_fetch: 966 logger.debug("Also fetching redaction events %s", events_to_fetch) 967 968 # build a map from event_id to EventBase 969 event_map: Dict[str, EventBase] = {} 970 for event_id, row in fetched_events.items(): 971 assert row.event_id == event_id 972 973 rejected_reason = row.rejected_reason 974 975 # If the event or metadata cannot be parsed, log the error and act 976 # as if the event is unknown. 977 try: 978 d = db_to_json(row.json) 979 except ValueError: 980 logger.error("Unable to parse json from event: %s", event_id) 981 continue 982 try: 983 internal_metadata = db_to_json(row.internal_metadata) 984 except ValueError: 985 logger.error( 986 "Unable to parse internal_metadata from event: %s", event_id 987 ) 988 continue 989 990 format_version = row.format_version 991 if format_version is None: 992 # This means that we stored the event before we had the concept 993 # of a event format version, so it must be a V1 event. 994 format_version = EventFormatVersions.V1 995 996 room_version_id = row.room_version_id 997 998 room_version: Optional[RoomVersion] 999 if not room_version_id: 1000 # this should only happen for out-of-band membership events which 1001 # arrived before #6983 landed. For all other events, we should have 1002 # an entry in the 'rooms' table. 1003 # 1004 # However, the 'out_of_band_membership' flag is unreliable for older 1005 # invites, so just accept it for all membership events. 1006 # 1007 if d["type"] != EventTypes.Member: 1008 raise Exception( 1009 "Room %s for event %s is unknown" % (d["room_id"], event_id) 1010 ) 1011 1012 # so, assuming this is an out-of-band-invite that arrived before #6983 1013 # landed, we know that the room version must be v5 or earlier (because 1014 # v6 hadn't been invented at that point, so invites from such rooms 1015 # would have been rejected.) 1016 # 1017 # The main reason we need to know the room version here (other than 1018 # choosing the right python Event class) is in case the event later has 1019 # to be redacted - and all the room versions up to v5 used the same 1020 # redaction algorithm. 1021 # 1022 # So, the following approximations should be adequate. 1023 1024 if format_version == EventFormatVersions.V1: 1025 # if it's event format v1 then it must be room v1 or v2 1026 room_version = RoomVersions.V1 1027 elif format_version == EventFormatVersions.V2: 1028 # if it's event format v2 then it must be room v3 1029 room_version = RoomVersions.V3 1030 else: 1031 # if it's event format v3 then it must be room v4 or v5 1032 room_version = RoomVersions.V5 1033 else: 1034 room_version = KNOWN_ROOM_VERSIONS.get(room_version_id) 1035 if not room_version: 1036 logger.warning( 1037 "Event %s in room %s has unknown room version %s", 1038 event_id, 1039 d["room_id"], 1040 room_version_id, 1041 ) 1042 continue 1043 1044 if room_version.event_format != format_version: 1045 logger.error( 1046 "Event %s in room %s with version %s has wrong format: " 1047 "expected %s, was %s", 1048 event_id, 1049 d["room_id"], 1050 room_version_id, 1051 room_version.event_format, 1052 format_version, 1053 ) 1054 continue 1055 1056 original_ev = make_event_from_dict( 1057 event_dict=d, 1058 room_version=room_version, 1059 internal_metadata_dict=internal_metadata, 1060 rejected_reason=rejected_reason, 1061 ) 1062 original_ev.internal_metadata.stream_ordering = row.stream_ordering 1063 original_ev.internal_metadata.outlier = row.outlier 1064 1065 event_map[event_id] = original_ev 1066 1067 # finally, we can decide whether each one needs redacting, and build 1068 # the cache entries. 1069 result_map: Dict[str, EventCacheEntry] = {} 1070 for event_id, original_ev in event_map.items(): 1071 redactions = fetched_events[event_id].redactions 1072 redacted_event = self._maybe_redact_event_row( 1073 original_ev, redactions, event_map 1074 ) 1075 1076 cache_entry = EventCacheEntry( 1077 event=original_ev, redacted_event=redacted_event 1078 ) 1079 1080 self._get_event_cache.set((event_id,), cache_entry) 1081 result_map[event_id] = cache_entry 1082 1083 return result_map 1084 1085 async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]: 1086 """Fetches events from the database using the _event_fetch_list. This 1087 allows batch and bulk fetching of events - it allows us to fetch events 1088 without having to create a new transaction for each request for events. 1089 1090 Args: 1091 events: events to be fetched. 1092 1093 Returns: 1094 A map from event id to row data from the database. May contain events 1095 that weren't requested. 1096 """ 1097 1098 events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred() 1099 with self._event_fetch_lock: 1100 self._event_fetch_list.append((events, events_d)) 1101 self._event_fetch_lock.notify() 1102 1103 self._maybe_start_fetch_thread() 1104 1105 logger.debug("Loading %d events: %s", len(events), events) 1106 with PreserveLoggingContext(): 1107 row_map = await events_d 1108 logger.debug("Loaded %d events (%d rows)", len(events), len(row_map)) 1109 1110 return row_map 1111 1112 def _fetch_event_rows( 1113 self, txn: LoggingTransaction, event_ids: Iterable[str] 1114 ) -> Dict[str, _EventRow]: 1115 """Fetch event rows from the database 1116 1117 Events which are not found are omitted from the result. 1118 1119 Args: 1120 txn: The database transaction. 1121 event_ids: event IDs to fetch 1122 1123 Returns: 1124 A map from event id to event info. 1125 """ 1126 event_dict = {} 1127 for evs in batch_iter(event_ids, 200): 1128 sql = """\ 1129 SELECT 1130 e.event_id, 1131 e.stream_ordering, 1132 ej.internal_metadata, 1133 ej.json, 1134 ej.format_version, 1135 r.room_version, 1136 rej.reason, 1137 e.outlier 1138 FROM events AS e 1139 JOIN event_json AS ej USING (event_id) 1140 LEFT JOIN rooms r ON r.room_id = e.room_id 1141 LEFT JOIN rejections as rej USING (event_id) 1142 WHERE """ 1143 1144 clause, args = make_in_list_sql_clause( 1145 txn.database_engine, "e.event_id", evs 1146 ) 1147 1148 txn.execute(sql + clause, args) 1149 1150 for row in txn: 1151 event_id = row[0] 1152 event_dict[event_id] = _EventRow( 1153 event_id=event_id, 1154 stream_ordering=row[1], 1155 internal_metadata=row[2], 1156 json=row[3], 1157 format_version=row[4], 1158 room_version_id=row[5], 1159 rejected_reason=row[6], 1160 redactions=[], 1161 outlier=row[7], 1162 ) 1163 1164 # check for redactions 1165 redactions_sql = "SELECT event_id, redacts FROM redactions WHERE " 1166 1167 clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs) 1168 1169 txn.execute(redactions_sql + clause, args) 1170 1171 for (redacter, redacted) in txn: 1172 d = event_dict.get(redacted) 1173 if d: 1174 d.redactions.append(redacter) 1175 1176 return event_dict 1177 1178 def _maybe_redact_event_row( 1179 self, 1180 original_ev: EventBase, 1181 redactions: Iterable[str], 1182 event_map: Dict[str, EventBase], 1183 ) -> Optional[EventBase]: 1184 """Given an event object and a list of possible redacting event ids, 1185 determine whether to honour any of those redactions and if so return a redacted 1186 event. 1187 1188 Args: 1189 original_ev: The original event. 1190 redactions: list of event ids of potential redaction events 1191 event_map: other events which have been fetched, in which we can 1192 look up the redaaction events. Map from event id to event. 1193 1194 Returns: 1195 If the event should be redacted, a pruned event object. Otherwise, None. 1196 """ 1197 if original_ev.type == "m.room.create": 1198 # we choose to ignore redactions of m.room.create events. 1199 return None 1200 1201 for redaction_id in redactions: 1202 redaction_event = event_map.get(redaction_id) 1203 if not redaction_event or redaction_event.rejected_reason: 1204 # we don't have the redaction event, or the redaction event was not 1205 # authorized. 1206 logger.debug( 1207 "%s was redacted by %s but redaction not found/authed", 1208 original_ev.event_id, 1209 redaction_id, 1210 ) 1211 continue 1212 1213 if redaction_event.room_id != original_ev.room_id: 1214 logger.debug( 1215 "%s was redacted by %s but redaction was in a different room!", 1216 original_ev.event_id, 1217 redaction_id, 1218 ) 1219 continue 1220 1221 # Starting in room version v3, some redactions need to be 1222 # rechecked if we didn't have the redacted event at the 1223 # time, so we recheck on read instead. 1224 if redaction_event.internal_metadata.need_to_check_redaction(): 1225 expected_domain = get_domain_from_id(original_ev.sender) 1226 if get_domain_from_id(redaction_event.sender) == expected_domain: 1227 # This redaction event is allowed. Mark as not needing a recheck. 1228 redaction_event.internal_metadata.recheck_redaction = False 1229 else: 1230 # Senders don't match, so the event isn't actually redacted 1231 logger.debug( 1232 "%s was redacted by %s but the senders don't match", 1233 original_ev.event_id, 1234 redaction_id, 1235 ) 1236 continue 1237 1238 logger.debug("Redacting %s due to %s", original_ev.event_id, redaction_id) 1239 1240 # we found a good redaction event. Redact! 1241 redacted_event = prune_event(original_ev) 1242 redacted_event.unsigned["redacted_by"] = redaction_id 1243 1244 # It's fine to add the event directly, since get_pdu_json 1245 # will serialise this field correctly 1246 redacted_event.unsigned["redacted_because"] = redaction_event 1247 1248 return redacted_event 1249 1250 # no valid redaction found for this event 1251 return None 1252 1253 async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]: 1254 """Given a list of event ids, check if we have already processed and 1255 stored them as non outliers. 1256 """ 1257 rows = await self.db_pool.simple_select_many_batch( 1258 table="events", 1259 retcols=("event_id",), 1260 column="event_id", 1261 iterable=list(event_ids), 1262 keyvalues={"outlier": False}, 1263 desc="have_events_in_timeline", 1264 ) 1265 1266 return {r["event_id"] for r in rows} 1267 1268 async def have_seen_events( 1269 self, room_id: str, event_ids: Iterable[str] 1270 ) -> Set[str]: 1271 """Given a list of event ids, check if we have already processed them. 1272 1273 The room_id is only used to structure the cache (so that it can later be 1274 invalidated by room_id) - there is no guarantee that the events are actually 1275 in the room in question. 1276 1277 Args: 1278 room_id: Room we are polling 1279 event_ids: events we are looking for 1280 1281 Returns: 1282 The set of events we have already seen. 1283 """ 1284 res = await self._have_seen_events_dict( 1285 (room_id, event_id) for event_id in event_ids 1286 ) 1287 return {eid for ((_rid, eid), have_event) in res.items() if have_event} 1288 1289 @cachedList("have_seen_event", "keys") 1290 async def _have_seen_events_dict( 1291 self, keys: Iterable[Tuple[str, str]] 1292 ) -> Dict[Tuple[str, str], bool]: 1293 """Helper for have_seen_events 1294 1295 Returns: 1296 a dict {(room_id, event_id)-> bool} 1297 """ 1298 # if the event cache contains the event, obviously we've seen it. 1299 1300 cache_results = { 1301 (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,)) 1302 } 1303 results = {x: True for x in cache_results} 1304 1305 def have_seen_events_txn( 1306 txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...] 1307 ) -> None: 1308 # we deliberately do *not* query the database for room_id, to make the 1309 # query an index-only lookup on `events_event_id_key`. 1310 # 1311 # We therefore pull the events from the database into a set... 1312 1313 sql = "SELECT event_id FROM events AS e WHERE " 1314 clause, args = make_in_list_sql_clause( 1315 txn.database_engine, "e.event_id", [eid for (_rid, eid) in chunk] 1316 ) 1317 txn.execute(sql + clause, args) 1318 found_events = {eid for eid, in txn} 1319 1320 # ... and then we can update the results for each row in the batch 1321 results.update({(rid, eid): (eid in found_events) for (rid, eid) in chunk}) 1322 1323 # each batch requires its own index scan, so we make the batches as big as 1324 # possible. 1325 for chunk in batch_iter((k for k in keys if k not in cache_results), 500): 1326 await self.db_pool.runInteraction( 1327 "have_seen_events", have_seen_events_txn, chunk 1328 ) 1329 1330 return results 1331 1332 @cached(max_entries=100000, tree=True) 1333 async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn: 1334 # this only exists for the benefit of the @cachedList descriptor on 1335 # _have_seen_events_dict 1336 raise NotImplementedError() 1337 1338 def _get_current_state_event_counts_txn( 1339 self, txn: LoggingTransaction, room_id: str 1340 ) -> int: 1341 """ 1342 See get_current_state_event_counts. 1343 """ 1344 sql = "SELECT COUNT(*) FROM current_state_events WHERE room_id=?" 1345 txn.execute(sql, (room_id,)) 1346 row = txn.fetchone() 1347 return row[0] if row else 0 1348 1349 async def get_current_state_event_counts(self, room_id: str) -> int: 1350 """ 1351 Gets the current number of state events in a room. 1352 1353 Args: 1354 room_id: The room ID to query. 1355 1356 Returns: 1357 The current number of state events. 1358 """ 1359 return await self.db_pool.runInteraction( 1360 "get_current_state_event_counts", 1361 self._get_current_state_event_counts_txn, 1362 room_id, 1363 ) 1364 1365 async def get_room_complexity(self, room_id: str) -> Dict[str, float]: 1366 """ 1367 Get a rough approximation of the complexity of the room. This is used by 1368 remote servers to decide whether they wish to join the room or not. 1369 Higher complexity value indicates that being in the room will consume 1370 more resources. 1371 1372 Args: 1373 room_id: The room ID to query. 1374 1375 Returns: 1376 dict[str:float] of complexity version to complexity. 1377 """ 1378 state_events = await self.get_current_state_event_counts(room_id) 1379 1380 # Call this one "v1", so we can introduce new ones as we want to develop 1381 # it. 1382 complexity_v1 = round(state_events / 500, 2) 1383 1384 return {"v1": complexity_v1} 1385 1386 async def get_all_new_forward_event_rows( 1387 self, instance_name: str, last_id: int, current_id: int, limit: int 1388 ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: 1389 """Returns new events, for the Events replication stream 1390 1391 Args: 1392 last_id: the last stream_id from the previous batch. 1393 current_id: the maximum stream_id to return up to 1394 limit: the maximum number of rows to return 1395 1396 Returns: 1397 a list of events stream rows. Each tuple consists of a stream id as 1398 the first element, followed by fields suitable for casting into an 1399 EventsStreamRow. 1400 """ 1401 1402 def get_all_new_forward_event_rows( 1403 txn: LoggingTransaction, 1404 ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: 1405 sql = ( 1406 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," 1407 " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" 1408 " FROM events AS e" 1409 " LEFT JOIN redactions USING (event_id)" 1410 " LEFT JOIN state_events AS se USING (event_id)" 1411 " LEFT JOIN event_relations USING (event_id)" 1412 " LEFT JOIN room_memberships USING (event_id)" 1413 " LEFT JOIN rejections USING (event_id)" 1414 " WHERE ? < stream_ordering AND stream_ordering <= ?" 1415 " AND instance_name = ?" 1416 " ORDER BY stream_ordering ASC" 1417 " LIMIT ?" 1418 ) 1419 txn.execute(sql, (last_id, current_id, instance_name, limit)) 1420 return cast( 1421 List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() 1422 ) 1423 1424 return await self.db_pool.runInteraction( 1425 "get_all_new_forward_event_rows", get_all_new_forward_event_rows 1426 ) 1427 1428 async def get_ex_outlier_stream_rows( 1429 self, instance_name: str, last_id: int, current_id: int 1430 ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: 1431 """Returns de-outliered events, for the Events replication stream 1432 1433 Args: 1434 last_id: the last stream_id from the previous batch. 1435 current_id: the maximum stream_id to return up to 1436 1437 Returns: 1438 a list of events stream rows. Each tuple consists of a stream id as 1439 the first element, followed by fields suitable for casting into an 1440 EventsStreamRow. 1441 """ 1442 1443 def get_ex_outlier_stream_rows_txn( 1444 txn: LoggingTransaction, 1445 ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]: 1446 sql = ( 1447 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," 1448 " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL" 1449 " FROM events AS e" 1450 " INNER JOIN ex_outlier_stream AS out USING (event_id)" 1451 " LEFT JOIN redactions USING (event_id)" 1452 " LEFT JOIN state_events AS se USING (event_id)" 1453 " LEFT JOIN event_relations USING (event_id)" 1454 " LEFT JOIN room_memberships USING (event_id)" 1455 " LEFT JOIN rejections USING (event_id)" 1456 " WHERE ? < event_stream_ordering" 1457 " AND event_stream_ordering <= ?" 1458 " AND out.instance_name = ?" 1459 " ORDER BY event_stream_ordering ASC" 1460 ) 1461 1462 txn.execute(sql, (last_id, current_id, instance_name)) 1463 return cast( 1464 List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall() 1465 ) 1466 1467 return await self.db_pool.runInteraction( 1468 "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn 1469 ) 1470 1471 async def get_all_new_backfill_event_rows( 1472 self, instance_name: str, last_id: int, current_id: int, limit: int 1473 ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]: 1474 """Get updates for backfill replication stream, including all new 1475 backfilled events and events that have gone from being outliers to not. 1476 1477 NOTE: The IDs given here are from replication, and so should be 1478 *positive*. 1479 1480 Args: 1481 instance_name: The writer we want to fetch updates from. Unused 1482 here since there is only ever one writer. 1483 last_id: The token to fetch updates from. Exclusive. 1484 current_id: The token to fetch updates up to. Inclusive. 1485 limit: The requested limit for the number of rows to return. The 1486 function may return more or fewer rows. 1487 1488 Returns: 1489 A tuple consisting of: the updates, a token to use to fetch 1490 subsequent updates, and whether we returned fewer rows than exists 1491 between the requested tokens due to the limit. 1492 1493 The token returned can be used in a subsequent call to this 1494 function to get further updatees. 1495 1496 The updates are a list of 2-tuples of stream ID and the row data 1497 """ 1498 if last_id == current_id: 1499 return [], current_id, False 1500 1501 def get_all_new_backfill_event_rows( 1502 txn: LoggingTransaction, 1503 ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]: 1504 sql = ( 1505 "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," 1506 " se.state_key, redacts, relates_to_id" 1507 " FROM events AS e" 1508 " LEFT JOIN redactions USING (event_id)" 1509 " LEFT JOIN state_events AS se USING (event_id)" 1510 " LEFT JOIN event_relations USING (event_id)" 1511 " WHERE ? > stream_ordering AND stream_ordering >= ?" 1512 " AND instance_name = ?" 1513 " ORDER BY stream_ordering ASC" 1514 " LIMIT ?" 1515 ) 1516 txn.execute(sql, (-last_id, -current_id, instance_name, limit)) 1517 new_event_updates: List[ 1518 Tuple[int, Tuple[str, str, str, str, str, str]] 1519 ] = [] 1520 row: Tuple[int, str, str, str, str, str, str] 1521 # Type safety: iterating over `txn` yields `Tuple`, i.e. 1522 # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a 1523 # variadic tuple to a fixed length tuple and flags it up as an error. 1524 for row in txn: # type: ignore[assignment] 1525 new_event_updates.append((row[0], row[1:])) 1526 1527 limited = False 1528 if len(new_event_updates) == limit: 1529 upper_bound = new_event_updates[-1][0] 1530 limited = True 1531 else: 1532 upper_bound = current_id 1533 1534 sql = ( 1535 "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," 1536 " se.state_key, redacts, relates_to_id" 1537 " FROM events AS e" 1538 " INNER JOIN ex_outlier_stream AS out USING (event_id)" 1539 " LEFT JOIN redactions USING (event_id)" 1540 " LEFT JOIN state_events AS se USING (event_id)" 1541 " LEFT JOIN event_relations USING (event_id)" 1542 " WHERE ? > event_stream_ordering" 1543 " AND event_stream_ordering >= ?" 1544 " AND out.instance_name = ?" 1545 " ORDER BY event_stream_ordering DESC" 1546 ) 1547 txn.execute(sql, (-last_id, -upper_bound, instance_name)) 1548 # Type safety: iterating over `txn` yields `Tuple`, i.e. 1549 # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a 1550 # variadic tuple to a fixed length tuple and flags it up as an error. 1551 for row in txn: # type: ignore[assignment] 1552 new_event_updates.append((row[0], row[1:])) 1553 1554 if len(new_event_updates) >= limit: 1555 upper_bound = new_event_updates[-1][0] 1556 limited = True 1557 1558 return new_event_updates, upper_bound, limited 1559 1560 return await self.db_pool.runInteraction( 1561 "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows 1562 ) 1563 1564 async def get_all_updated_current_state_deltas( 1565 self, instance_name: str, from_token: int, to_token: int, target_row_count: int 1566 ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]: 1567 """Fetch updates from current_state_delta_stream 1568 1569 Args: 1570 from_token: The previous stream token. Updates from this stream id will 1571 be excluded. 1572 1573 to_token: The current stream token (ie the upper limit). Updates up to this 1574 stream id will be included (modulo the 'limit' param) 1575 1576 target_row_count: The number of rows to try to return. If more rows are 1577 available, we will set 'limited' in the result. In the event of a large 1578 batch, we may return more rows than this. 1579 Returns: 1580 A triplet `(updates, new_last_token, limited)`, where: 1581 * `updates` is a list of database tuples. 1582 * `new_last_token` is the new position in stream. 1583 * `limited` is whether there are more updates to fetch. 1584 """ 1585 1586 def get_all_updated_current_state_deltas_txn( 1587 txn: LoggingTransaction, 1588 ) -> List[Tuple[int, str, str, str, str]]: 1589 sql = """ 1590 SELECT stream_id, room_id, type, state_key, event_id 1591 FROM current_state_delta_stream 1592 WHERE ? < stream_id AND stream_id <= ? 1593 AND instance_name = ? 1594 ORDER BY stream_id ASC LIMIT ? 1595 """ 1596 txn.execute(sql, (from_token, to_token, instance_name, target_row_count)) 1597 return cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) 1598 1599 def get_deltas_for_stream_id_txn( 1600 txn: LoggingTransaction, stream_id: int 1601 ) -> List[Tuple[int, str, str, str, str]]: 1602 sql = """ 1603 SELECT stream_id, room_id, type, state_key, event_id 1604 FROM current_state_delta_stream 1605 WHERE stream_id = ? 1606 """ 1607 txn.execute(sql, [stream_id]) 1608 return cast(List[Tuple[int, str, str, str, str]], txn.fetchall()) 1609 1610 # we need to make sure that, for every stream id in the results, we get *all* 1611 # the rows with that stream id. 1612 1613 rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction( 1614 "get_all_updated_current_state_deltas", 1615 get_all_updated_current_state_deltas_txn, 1616 ) 1617 1618 # if we've got fewer rows than the limit, we're good 1619 if len(rows) < target_row_count: 1620 return rows, to_token, False 1621 1622 # we hit the limit, so reduce the upper limit so that we exclude the stream id 1623 # of the last row in the result. 1624 assert rows[-1][0] <= to_token 1625 to_token = rows[-1][0] - 1 1626 1627 # search backwards through the list for the point to truncate 1628 for idx in range(len(rows) - 1, 0, -1): 1629 if rows[idx - 1][0] <= to_token: 1630 return rows[:idx], to_token, True 1631 1632 # bother. We didn't get a full set of changes for even a single 1633 # stream id. let's run the query again, without a row limit, but for 1634 # just one stream id. 1635 to_token += 1 1636 rows = await self.db_pool.runInteraction( 1637 "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token 1638 ) 1639 1640 return rows, to_token, True 1641 1642 async def is_event_after(self, event_id1: str, event_id2: str) -> bool: 1643 """Returns True if event_id1 is after event_id2 in the stream""" 1644 to_1, so_1 = await self.get_event_ordering(event_id1) 1645 to_2, so_2 = await self.get_event_ordering(event_id2) 1646 return (to_1, so_1) > (to_2, so_2) 1647 1648 @cached(max_entries=5000) 1649 async def get_event_ordering(self, event_id: str) -> Tuple[int, int]: 1650 res = await self.db_pool.simple_select_one( 1651 table="events", 1652 retcols=["topological_ordering", "stream_ordering"], 1653 keyvalues={"event_id": event_id}, 1654 allow_none=True, 1655 ) 1656 1657 if not res: 1658 raise SynapseError(404, "Could not find event %s" % (event_id,)) 1659 1660 return int(res["topological_ordering"]), int(res["stream_ordering"]) 1661 1662 async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]: 1663 """Retrieve the entry with the lowest expiry timestamp in the event_expiry 1664 table, or None if there's no more event to expire. 1665 1666 Returns: 1667 A tuple containing the event ID as its first element and an expiry timestamp 1668 as its second one, if there's at least one row in the event_expiry table. 1669 None otherwise. 1670 """ 1671 1672 def get_next_event_to_expire_txn( 1673 txn: LoggingTransaction, 1674 ) -> Optional[Tuple[str, int]]: 1675 txn.execute( 1676 """ 1677 SELECT event_id, expiry_ts FROM event_expiry 1678 ORDER BY expiry_ts ASC LIMIT 1 1679 """ 1680 ) 1681 1682 return cast(Optional[Tuple[str, int]], txn.fetchone()) 1683 1684 return await self.db_pool.runInteraction( 1685 desc="get_next_event_to_expire", func=get_next_event_to_expire_txn 1686 ) 1687 1688 async def get_event_id_from_transaction_id( 1689 self, room_id: str, user_id: str, token_id: int, txn_id: str 1690 ) -> Optional[str]: 1691 """Look up if we have already persisted an event for the transaction ID, 1692 returning the event ID if so. 1693 """ 1694 return await self.db_pool.simple_select_one_onecol( 1695 table="event_txn_id", 1696 keyvalues={ 1697 "room_id": room_id, 1698 "user_id": user_id, 1699 "token_id": token_id, 1700 "txn_id": txn_id, 1701 }, 1702 retcol="event_id", 1703 allow_none=True, 1704 desc="get_event_id_from_transaction_id", 1705 ) 1706 1707 async def get_already_persisted_events( 1708 self, events: Iterable[EventBase] 1709 ) -> Dict[str, str]: 1710 """Look up if we have already persisted an event for the transaction ID, 1711 returning a mapping from event ID in the given list to the event ID of 1712 an existing event. 1713 1714 Also checks if there are duplicates in the given events, if there are 1715 will map duplicates to the *first* event. 1716 """ 1717 1718 mapping = {} 1719 txn_id_to_event: Dict[Tuple[str, int, str], str] = {} 1720 1721 for event in events: 1722 token_id = getattr(event.internal_metadata, "token_id", None) 1723 txn_id = getattr(event.internal_metadata, "txn_id", None) 1724 1725 if token_id and txn_id: 1726 # Check if this is a duplicate of an event in the given events. 1727 existing = txn_id_to_event.get((event.room_id, token_id, txn_id)) 1728 if existing: 1729 mapping[event.event_id] = existing 1730 continue 1731 1732 # Check if this is a duplicate of an event we've already 1733 # persisted. 1734 existing = await self.get_event_id_from_transaction_id( 1735 event.room_id, event.sender, token_id, txn_id 1736 ) 1737 if existing: 1738 mapping[event.event_id] = existing 1739 txn_id_to_event[(event.room_id, token_id, txn_id)] = existing 1740 else: 1741 txn_id_to_event[(event.room_id, token_id, txn_id)] = event.event_id 1742 1743 return mapping 1744 1745 @wrap_as_background_process("_cleanup_old_transaction_ids") 1746 async def _cleanup_old_transaction_ids(self) -> None: 1747 """Cleans out transaction id mappings older than 24hrs.""" 1748 1749 def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None: 1750 sql = """ 1751 DELETE FROM event_txn_id 1752 WHERE inserted_ts < ? 1753 """ 1754 one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000 1755 txn.execute(sql, (one_day_ago,)) 1756 1757 return await self.db_pool.runInteraction( 1758 "_cleanup_old_transaction_ids", 1759 _cleanup_old_transaction_ids_txn, 1760 ) 1761 1762 async def is_event_next_to_backward_gap(self, event: EventBase) -> bool: 1763 """Check if the given event is next to a backward gap of missing events. 1764 <latest messages> A(False)--->B(False)--->C(True)---> <gap, unknown events> <oldest messages> 1765 1766 Args: 1767 room_id: room where the event lives 1768 event_id: event to check 1769 1770 Returns: 1771 Boolean indicating whether it's an extremity 1772 """ 1773 1774 def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool: 1775 # If the event in question has any of its prev_events listed as a 1776 # backward extremity, it's next to a gap. 1777 # 1778 # We can't just check the backward edges in `event_edges` because 1779 # when we persist events, we will also record the prev_events as 1780 # edges to the event in question regardless of whether we have those 1781 # prev_events yet. We need to check whether those prev_events are 1782 # backward extremities, also known as gaps, that need to be 1783 # backfilled. 1784 backward_extremity_query = """ 1785 SELECT 1 FROM event_backward_extremities 1786 WHERE 1787 room_id = ? 1788 AND %s 1789 LIMIT 1 1790 """ 1791 1792 # If the event in question is a backward extremity or has any of its 1793 # prev_events listed as a backward extremity, it's next to a 1794 # backward gap. 1795 clause, args = make_in_list_sql_clause( 1796 self.database_engine, 1797 "event_id", 1798 [event.event_id] + list(event.prev_event_ids()), 1799 ) 1800 1801 txn.execute(backward_extremity_query % (clause,), [event.room_id] + args) 1802 backward_extremities = txn.fetchall() 1803 1804 # We consider any backward extremity as a backward gap 1805 if len(backward_extremities): 1806 return True 1807 1808 return False 1809 1810 return await self.db_pool.runInteraction( 1811 "is_event_next_to_backward_gap_txn", 1812 is_event_next_to_backward_gap_txn, 1813 ) 1814 1815 async def is_event_next_to_forward_gap(self, event: EventBase) -> bool: 1816 """Check if the given event is next to a forward gap of missing events. 1817 The gap in front of the latest events is not considered a gap. 1818 <latest messages> A(False)--->B(False)--->C(False)---> <gap, unknown events> <oldest messages> 1819 <latest messages> A(False)--->B(False)---> <gap, unknown events> --->D(True)--->E(False) <oldest messages> 1820 1821 Args: 1822 room_id: room where the event lives 1823 event_id: event to check 1824 1825 Returns: 1826 Boolean indicating whether it's an extremity 1827 """ 1828 1829 def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool: 1830 # If the event in question is a forward extremity, we will just 1831 # consider any potential forward gap as not a gap since it's one of 1832 # the latest events in the room. 1833 # 1834 # `event_forward_extremities` does not include backfilled or outlier 1835 # events so we can't rely on it to find forward gaps. We can only 1836 # use it to determine whether a message is the latest in the room. 1837 # 1838 # We can't combine this query with the `forward_edge_query` below 1839 # because if the event in question has no forward edges (isn't 1840 # referenced by any other event's prev_events) but is in 1841 # `event_forward_extremities`, we don't want to return 0 rows and 1842 # say it's next to a gap. 1843 forward_extremity_query = """ 1844 SELECT 1 FROM event_forward_extremities 1845 WHERE 1846 room_id = ? 1847 AND event_id = ? 1848 LIMIT 1 1849 """ 1850 1851 # Check to see whether the event in question is already referenced 1852 # by another event. If we don't see any edges, we're next to a 1853 # forward gap. 1854 forward_edge_query = """ 1855 SELECT 1 FROM event_edges 1856 /* Check to make sure the event referencing our event in question is not rejected */ 1857 LEFT JOIN rejections ON event_edges.event_id == rejections.event_id 1858 WHERE 1859 event_edges.room_id = ? 1860 AND event_edges.prev_event_id = ? 1861 /* It's not a valid edge if the event referencing our event in 1862 * question is rejected. 1863 */ 1864 AND rejections.event_id IS NULL 1865 LIMIT 1 1866 """ 1867 1868 # We consider any forward extremity as the latest in the room and 1869 # not a forward gap. 1870 # 1871 # To expand, even though there is technically a gap at the front of 1872 # the room where the forward extremities are, we consider those the 1873 # latest messages in the room so asking other homeservers for more 1874 # is useless. The new latest messages will just be federated as 1875 # usual. 1876 txn.execute(forward_extremity_query, (event.room_id, event.event_id)) 1877 forward_extremities = txn.fetchall() 1878 if len(forward_extremities): 1879 return False 1880 1881 # If there are no forward edges to the event in question (another 1882 # event hasn't referenced this event in their prev_events), then we 1883 # assume there is a forward gap in the history. 1884 txn.execute(forward_edge_query, (event.room_id, event.event_id)) 1885 forward_edges = txn.fetchall() 1886 if not len(forward_edges): 1887 return True 1888 1889 return False 1890 1891 return await self.db_pool.runInteraction( 1892 "is_event_next_to_gap_txn", 1893 is_event_next_to_gap_txn, 1894 ) 1895 1896 async def get_event_id_for_timestamp( 1897 self, room_id: str, timestamp: int, direction: str 1898 ) -> Optional[str]: 1899 """Find the closest event to the given timestamp in the given direction. 1900 1901 Args: 1902 room_id: Room to fetch the event from 1903 timestamp: The point in time (inclusive) we should navigate from in 1904 the given direction to find the closest event. 1905 direction: ["f"|"b"] to indicate whether we should navigate forward 1906 or backward from the given timestamp to find the closest event. 1907 1908 Returns: 1909 The closest event_id otherwise None if we can't find any event in 1910 the given direction. 1911 """ 1912 1913 sql_template = """ 1914 SELECT event_id FROM events 1915 LEFT JOIN rejections USING (event_id) 1916 WHERE 1917 origin_server_ts %s ? 1918 AND room_id = ? 1919 /* Make sure event is not rejected */ 1920 AND rejections.event_id IS NULL 1921 ORDER BY origin_server_ts %s 1922 LIMIT 1; 1923 """ 1924 1925 def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]: 1926 if direction == "b": 1927 # Find closest event *before* a given timestamp. We use descending 1928 # (which gives values largest to smallest) because we want the 1929 # largest possible timestamp *before* the given timestamp. 1930 comparison_operator = "<=" 1931 order = "DESC" 1932 else: 1933 # Find closest event *after* a given timestamp. We use ascending 1934 # (which gives values smallest to largest) because we want the 1935 # closest possible timestamp *after* the given timestamp. 1936 comparison_operator = ">=" 1937 order = "ASC" 1938 1939 txn.execute( 1940 sql_template % (comparison_operator, order), (timestamp, room_id) 1941 ) 1942 row = txn.fetchone() 1943 if row: 1944 (event_id,) = row 1945 return event_id 1946 1947 return None 1948 1949 if direction not in ("f", "b"): 1950 raise ValueError("Unknown direction: %s" % (direction,)) 1951 1952 return await self.db_pool.runInteraction( 1953 "get_event_id_for_timestamp_txn", 1954 get_event_id_for_timestamp_txn, 1955 ) 1956