# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A replication client for use by synapse workers.
"""
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

from twisted.internet.defer import Deferred
from twisted.internet.protocol import ReconnectingClientFactory

from synapse.api.constants import EventTypes
from synapse.federation import send_queue
from synapse.federation.sender import FederationSender
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.streams import (
    AccountDataStream,
    DeviceListsStream,
    GroupServerStream,
    PushersStream,
    PushRulesStream,
    ReceiptsStream,
    TagAccountDataStream,
    ToDeviceStream,
    TypingStream,
)
from synapse.replication.tcp.streams.events import (
    EventsStream,
    EventsStreamEventRow,
    EventsStreamRow,
)
from synapse.types import PersistedEventPosition, ReadReceipt, UserID
from synapse.util.async_helpers import Linearizer, timeout_deferred
from synapse.util.metrics import Measure

if TYPE_CHECKING:
    from synapse.replication.tcp.handler import ReplicationCommandHandler
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)

# How long we allow callers to wait for replication updates before timing out.
_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30


class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
    """Factory for building connections to the master. Will reconnect if the
    connection is lost.

    Accepts a handler that is passed to `ClientReplicationStreamProtocol`.
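
    A worker typically wires this factory up via the Twisted reactor (an
    illustrative sketch; the host and port here are assumptions, not synapse
    defaults):

        factory = DirectTcpReplicationClientFactory(hs, "worker1", command_handler)
        hs.get_reactor().connectTCP("replication.example.com", 9092, factory)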
    """

    initialDelay = 0.1
    maxDelay = 1  # Try at least once a second

    def __init__(
        self,
        hs: "HomeServer",
        client_name: str,
        command_handler: "ReplicationCommandHandler",
    ):
        self.client_name = client_name
        self.command_handler = command_handler
        self.server_name = hs.config.server.server_name
        self.hs = hs
        self._clock = hs.get_clock()  # `self.clock` is already defined in the superclass

        hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)

    def startedConnecting(self, connector):
        logger.info("Connecting to replication: %r", connector.getDestination())

    def buildProtocol(self, addr):
        logger.info("Connected to replication: %r", addr)
        return ClientReplicationStreamProtocol(
            self.hs,
            self.client_name,
            self.server_name,
            self._clock,
            self.command_handler,
        )

    def clientConnectionLost(self, connector, reason):
        logger.error("Lost replication conn: %r", reason)
        ReconnectingClientFactory.clientConnectionLost(self, connector, reason)

    def clientConnectionFailed(self, connector, reason):
        logger.error("Failed to connect to replication: %r", reason)
        ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)


class ReplicationDataHandler:
    """Handles incoming stream updates from replication.

    This instance notifies the slave data store about updates. Can be subclassed
    to handle updates in additional ways.
    """

    def __init__(self, hs: "HomeServer"):
        self.store = hs.get_datastore()
        self.notifier = hs.get_notifier()
        self._reactor = hs.get_reactor()
        self._clock = hs.get_clock()
        self._streams = hs.get_replication_streams()
        self._instance_name = hs.get_instance_name()
        self._typing_handler = hs.get_typing_handler()

        self._notify_pushers = hs.config.worker.start_pushers
        self._pusher_pool = hs.get_pusherpool()
        self._presence_handler = hs.get_presence_handler()

        self.send_handler: Optional[FederationSenderHandler] = None
        if hs.should_send_federation():
            self.send_handler = FederationSenderHandler(hs)

        # Map from stream to list of deferreds waiting for the stream to
        # arrive at a particular position. The lists are sorted by stream position.
        self._streams_to_waiters: Dict[str, List[Tuple[int, Deferred]]] = {}

    async def on_rdata(
        self, stream_name: str, instance_name: str, token: int, rows: list
    ):
        """Called to handle a batch of replication data with a given stream token.

        By default this just pokes the slave store. Can be overridden in subclasses to
        handle more.

        Args:
            stream_name: name of the replication stream for this batch of rows
            instance_name: the instance that wrote the rows.
            token: stream token for this batch of rows
            rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
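
        Subclasses that want extra behaviour typically override this, call
        `super().on_rdata(...)` first, and then act on the rows. A minimal
        sketch (the subclass name and the logging are illustrative, not part
        of synapse):

            class LoggingReplicationDataHandler(ReplicationDataHandler):
                async def on_rdata(self, stream_name, instance_name, token, rows):
                    await super().on_rdata(stream_name, instance_name, token, rows)
                    logger.debug("Saw %d rows on %r", len(rows), stream_name)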
        """
        self.store.process_replication_rows(stream_name, instance_name, token, rows)

        if self.send_handler:
            await self.send_handler.process_replication_rows(stream_name, token, rows)

        if stream_name == TypingStream.NAME:
            self._typing_handler.process_replication_rows(token, rows)
            self.notifier.on_new_event(
                "typing_key", token, rooms=[row.room_id for row in rows]
            )
        elif stream_name == PushRulesStream.NAME:
            self.notifier.on_new_event(
                "push_rules_key", token, users=[row.user_id for row in rows]
            )
        elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
            self.notifier.on_new_event(
                "account_data_key", token, users=[row.user_id for row in rows]
            )
        elif stream_name == ReceiptsStream.NAME:
            self.notifier.on_new_event(
                "receipt_key", token, rooms=[row.room_id for row in rows]
            )
            await self._pusher_pool.on_new_receipts(
                token, token, {row.room_id for row in rows}
            )
        elif stream_name == ToDeviceStream.NAME:
            entities = [row.entity for row in rows if row.entity.startswith("@")]
            if entities:
                self.notifier.on_new_event("to_device_key", token, users=entities)
        elif stream_name == DeviceListsStream.NAME:
            all_room_ids: Set[str] = set()
            for row in rows:
                if row.entity.startswith("@"):
                    room_ids = await self.store.get_rooms_for_user(row.entity)
                    all_room_ids.update(room_ids)
            self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
        elif stream_name == GroupServerStream.NAME:
            self.notifier.on_new_event(
                "groups_key", token, users=[row.user_id for row in rows]
            )
        elif stream_name == PushersStream.NAME:
            for row in rows:
                if row.deleted:
                    self.stop_pusher(row.user_id, row.app_id, row.pushkey)
                else:
                    await self.start_pusher(row.user_id, row.app_id, row.pushkey)
        elif stream_name == EventsStream.NAME:
            # We shouldn't get multiple rows per token for the events stream, so
            # we don't need to optimise this for multiple rows.
            for row in rows:
                if row.type != EventsStreamEventRow.TypeId:
                    continue
                assert isinstance(row, EventsStreamRow)
                assert isinstance(row.data, EventsStreamEventRow)

                if row.data.rejected:
                    continue

                extra_users: Tuple[UserID, ...] = ()
                if row.data.type == EventTypes.Member and row.data.state_key:
                    extra_users = (UserID.from_string(row.data.state_key),)

                max_token = self.store.get_room_max_token()
                event_pos = PersistedEventPosition(instance_name, token)
                await self.notifier.on_new_room_event_args(
                    event_pos=event_pos,
                    max_room_stream_token=max_token,
                    extra_users=extra_users,
                    room_id=row.data.room_id,
                    event_id=row.data.event_id,
                    event_type=row.data.type,
                    state_key=row.data.state_key,
                    membership=row.data.membership,
                )

        await self._presence_handler.process_replication_rows(
            stream_name, instance_name, token, rows
        )

        # Notify any waiting deferreds. The list is ordered by position so we
        # just iterate through the list until we reach a position that is
        # greater than the received row position.
        waiting_list = self._streams_to_waiters.get(stream_name, [])

        # Index of the first item with a position after the current token, i.e.
        # we have called all deferreds before this index. If it is not
        # overwritten by the loop below, then either a) there are no items in
        # the list, so this is a no-op, or b) all items in the list were called
        # and the list should be cleared. Setting it to `len(list)` works for
        # both cases.
        index_of_first_deferred_not_called = len(waiting_list)
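
        # Worked example (illustrative): if waiting_list is
        # [(3, d1), (5, d2), (9, d3)] and token is 5, the loop below fires d1
        # and d2, and the trim at the end leaves [(9, d3)].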

        for idx, (position, deferred) in enumerate(waiting_list):
            if position <= token:
                try:
                    with PreserveLoggingContext():
                        deferred.callback(None)
                except Exception:
                    # The deferred has been cancelled or timed out.
                    pass
            else:
                # The list is sorted by position so we don't need to continue
                # checking any further entries in the list.
                index_of_first_deferred_not_called = idx
                break

        # Drop all entries in the waiting list that were called in the above
        # loop. (This maintains the order, so there is no need to re-sort.)
        waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]

    async def on_position(self, stream_name: str, instance_name: str, token: int):
        await self.on_rdata(stream_name, instance_name, token, [])

        # We poke the generic "replication" notifier to wake anything up that
        # may be streaming.
        self.notifier.notify_replication()

    def on_remote_server_up(self, server: str):
        """Called when we get a new REMOTE_SERVER_UP command."""

        # Let's wake up the transaction queue for the server in case we have
        # pending stuff to send to it.
        if self.send_handler:
            self.send_handler.wake_destination(server)

    async def wait_for_stream_position(
        self, instance_name: str, stream_name: str, position: int
    ):
        """Wait until this instance has received updates up to and including
        the given stream position.
        """

        if instance_name == self._instance_name:
            # We don't get told about updates written by this process, and
            # anyway in that case we don't need to wait.
            return

        current_position = self._streams[stream_name].current_token(self._instance_name)
        if position <= current_position:
            # We're already past the position.
            return

        # Create a new deferred that times out after N seconds, as we don't want
        # to wedge here forever.
        deferred: "Deferred[None]" = Deferred()
        deferred = timeout_deferred(
            deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
        )

        waiting_list = self._streams_to_waiters.setdefault(stream_name, [])

        waiting_list.append((position, deferred))
        waiting_list.sort(key=lambda t: t[0])

        # We measure here to get in-flight counts and the average waiting time.
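        # (`Measure` is a context manager that records the duration of the
        # wrapped block and the number of such blocks currently in flight,
        # under the given metric name.)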
        with Measure(self._clock, "repl.wait_for_stream_position"):
            logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
            await make_deferred_yieldable(deferred)
            logger.info(
                "Finished waiting for repl stream %r to reach %s", stream_name, position
            )

    def stop_pusher(self, user_id, app_id, pushkey):
        if not self._notify_pushers:
            return

        key = "%s:%s" % (app_id, pushkey)
        pushers_for_user = self._pusher_pool.pushers.get(user_id, {})
        pusher = pushers_for_user.pop(key, None)
        if pusher is None:
            return
        logger.info("Stopping pusher %r / %r", user_id, key)
        pusher.on_stop()

    async def start_pusher(self, user_id, app_id, pushkey):
        if not self._notify_pushers:
            return

        key = "%s:%s" % (app_id, pushkey)
        logger.info("Starting pusher %r / %r", user_id, key)
        return await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)


class FederationSenderHandler:
    """Processes the federation replication stream.

    This class is only instantiated on the worker responsible for sending outbound
    federation transactions. It receives rows from the replication stream and forwards
    the appropriate entries to the FederationSender class.
    """

    def __init__(self, hs: "HomeServer"):
        assert hs.should_send_federation()

        self.store = hs.get_datastore()
        self._is_mine_id = hs.is_mine_id
        self._hs = hs

        # We need to make a temporary value to ensure that mypy picks up the
        # right type. We know we should have a federation sender instance since
        # `should_send_federation` is True.
        sender = hs.get_federation_sender()
        assert isinstance(sender, FederationSender)
        self.federation_sender = sender

        # Stores the latest position in the federation stream we've gotten up
        # to. This is always set before we use it.
        self.federation_position: Optional[int] = None

        self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")

    def wake_destination(self, server: str):
        self.federation_sender.wake_destination(server)

    async def process_replication_rows(self, stream_name, token, rows):
        # The federation stream contains things that we want to send out, e.g.
        # presence, typing, etc.
        if stream_name == "federation":
            send_queue.process_rows_for_federation(self.federation_sender, rows)
            await self.update_token(token)

        # ... and when new receipts happen
        elif stream_name == ReceiptsStream.NAME:
            await self._on_new_receipts(rows)

        # ... as well as device updates and messages
        elif stream_name == DeviceListsStream.NAME:
            # The entities are either user IDs (starting with '@') whose devices
            # have changed, or remote servers that we need to tell about
            # changes.
            hosts = {row.entity for row in rows if not row.entity.startswith("@")}
            for host in hosts:
                self.federation_sender.send_device_messages(host)

        elif stream_name == ToDeviceStream.NAME:
            # The to_device stream includes stuff to be pushed to both local
            # clients and remote servers, so we ignore entities that start with
            # '@' (since they'll be local users rather than destinations).
            hosts = {row.entity for row in rows if not row.entity.startswith("@")}
            for host in hosts:
                self.federation_sender.send_device_messages(host)

    async def _on_new_receipts(self, rows):
        """
        Args:
            rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
                new receipts to be processed
        """
        for receipt in rows:
            # We only want to send out receipts for our own users.
            if not self._is_mine_id(receipt.user_id):
                continue
            # MSC2285 hidden read receipts are never sent over federation.
            if (
                receipt.data.get("hidden", False)
                and self._hs.config.experimental.msc2285_enabled
            ):
                continue
            receipt_info = ReadReceipt(
                receipt.room_id,
                receipt.receipt_type,
                receipt.user_id,
                [receipt.event_id],
                receipt.data,
            )
            await self.federation_sender.send_read_receipt(receipt_info)

    async def update_token(self, token):
        """Update the record of where we have processed to in the federation stream.

        Called after we have processed an update received over replication. Sends
        a FEDERATION_ACK back to the master, and stores the token that we have processed
        in `federation_stream_position` so that we can restart where we left off.
        """
        self.federation_position = token

        # We save and send the ACK to master asynchronously, so we don't block
        # processing on persistence. We don't need to do this operation for
        # every single RDATA we receive, we just need to do it periodically.

        if self._fed_position_linearizer.is_queued(None):
            # There is already a task queued up to save and send the token, so
            # no need to queue up another task.
            return

        run_as_background_process("_save_and_send_ack", self._save_and_send_ack)

    async def _save_and_send_ack(self):
        """Save the current federation position in the database and send an ACK
        to master with where we're up to.
        """
        # We should only be calling this once we've got a token.
        assert self.federation_position is not None

        try:
            # We linearize here to ensure we don't have races updating the token.
            #
            # XXX this appears to be redundant, since the ReplicationCommandHandler
            # has a linearizer which ensures that we only process one line of
            # replication data at a time. Should we remove it, or is it doing useful
            # service for robustness? Or could we replace it with an assertion that
            # we're not being re-entered?

            with (await self._fed_position_linearizer.queue(None)):
                # We persist and ack the same position, so we take a copy of it
                # here as otherwise it can get modified from underneath us.
                current_position = self.federation_position

                await self.store.update_federation_out_pos(
                    "federation", current_position
                )

                # We ACK this token over replication so that the master can drop
                # its in-memory queues.
                self._hs.get_tcp_replication().send_federation_ack(current_position)
        except Exception:
            logger.exception("Error updating federation stream position")
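

# Illustrative usage (not part of this module): a worker that has been told
# about a stream position by another instance can wait for its local store to
# catch up before acting on it, along the lines of:
#
#     await hs.get_replication_data_handler().wait_for_stream_position(
#         instance_name="master", stream_name=EventsStream.NAME, position=token
#     )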