# Copyright 2017 Vector Creations Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A replication client for use by synapse workers.
"""
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

from twisted.internet.defer import Deferred
from twisted.internet.protocol import ReconnectingClientFactory

from synapse.api.constants import EventTypes
from synapse.federation import send_queue
from synapse.federation.sender import FederationSender
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.streams import (
    AccountDataStream,
    DeviceListsStream,
    GroupServerStream,
    PushersStream,
    PushRulesStream,
    ReceiptsStream,
    TagAccountDataStream,
    ToDeviceStream,
    TypingStream,
)
from synapse.replication.tcp.streams.events import (
    EventsStream,
    EventsStreamEventRow,
    EventsStreamRow,
)
from synapse.types import PersistedEventPosition, ReadReceipt, UserID
from synapse.util.async_helpers import Linearizer, timeout_deferred
from synapse.util.metrics import Measure

if TYPE_CHECKING:
    from synapse.replication.tcp.handler import ReplicationCommandHandler
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)

# How long we allow callers to wait for replication updates before timing out.
_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30


class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
    """Factory for building connections to the master. Will reconnect if the
    connection is lost.

    Accepts a handler that is passed to `ClientReplicationStreamProtocol`.
    """

    initialDelay = 0.1
    maxDelay = 1  # Try at least once every N seconds

    def __init__(
        self,
        hs: "HomeServer",
        client_name: str,
        command_handler: "ReplicationCommandHandler",
    ):
        self.client_name = client_name
        self.command_handler = command_handler
        self.server_name = hs.config.server.server_name
        self.hs = hs
        self._clock = hs.get_clock()  # `self.clock` is already defined in the superclass

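        # Stop trying to reconnect once the reactor begins shutting down;
        # otherwise the factory would keep attempting reconnections during
        # shutdown.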
        hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)

    def startedConnecting(self, connector):
        logger.info("Connecting to replication: %r", connector.getDestination())

    def buildProtocol(self, addr):
        logger.info("Connected to replication: %r", addr)
        return ClientReplicationStreamProtocol(
            self.hs,
            self.client_name,
            self.server_name,
            self._clock,
            self.command_handler,
        )

    def clientConnectionLost(self, connector, reason):
        logger.error("Lost replication conn: %r", reason)
        ReconnectingClientFactory.clientConnectionLost(self, connector, reason)

    def clientConnectionFailed(self, connector, reason):
        logger.error("Failed to connect to replication: %r", reason)
        ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)


class ReplicationDataHandler:
    """Handles incoming stream updates from replication.

    This instance notifies the slave data store about updates. Can be subclassed
    to handle updates in additional ways.
    """

    def __init__(self, hs: "HomeServer"):
        self.store = hs.get_datastore()
        self.notifier = hs.get_notifier()
        self._reactor = hs.get_reactor()
        self._clock = hs.get_clock()
        self._streams = hs.get_replication_streams()
        self._instance_name = hs.get_instance_name()
        self._typing_handler = hs.get_typing_handler()

        self._notify_pushers = hs.config.worker.start_pushers
        self._pusher_pool = hs.get_pusherpool()
        self._presence_handler = hs.get_presence_handler()

        self.send_handler: Optional[FederationSenderHandler] = None
        if hs.should_send_federation():
            self.send_handler = FederationSenderHandler(hs)

        # Map from stream to list of deferreds waiting for the stream to
        # arrive at a particular position. The lists are sorted by stream position.
        self._streams_to_waiters: Dict[str, List[Tuple[int, Deferred]]] = {}

    async def on_rdata(
        self, stream_name: str, instance_name: str, token: int, rows: list
    ):
        """Called to handle a batch of replication data with a given stream token.

        By default this just pokes the slave store. Can be overridden in subclasses to
        handle more.

        Args:
            stream_name: name of the replication stream for this batch of rows
            instance_name: the instance that wrote the rows.
            token: stream token for this batch of rows
            rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
        """
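        # Poke the store first so that it has processed the rows (e.g.
        # invalidated any affected caches) before we notify anything else.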
        self.store.process_replication_rows(stream_name, instance_name, token, rows)

        if self.send_handler:
            await self.send_handler.process_replication_rows(stream_name, token, rows)

        if stream_name == TypingStream.NAME:
            self._typing_handler.process_replication_rows(token, rows)
            self.notifier.on_new_event(
                "typing_key", token, rooms=[row.room_id for row in rows]
            )
        elif stream_name == PushRulesStream.NAME:
            self.notifier.on_new_event(
                "push_rules_key", token, users=[row.user_id for row in rows]
            )
        elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
            self.notifier.on_new_event(
                "account_data_key", token, users=[row.user_id for row in rows]
            )
        elif stream_name == ReceiptsStream.NAME:
            self.notifier.on_new_event(
                "receipt_key", token, rooms=[row.room_id for row in rows]
            )
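            # Wake the pushers: `token` is used as both the min and max stream
            # id, since all the rows in this batch share the same token.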
            await self._pusher_pool.on_new_receipts(
                token, token, {row.room_id for row in rows}
            )
        elif stream_name == ToDeviceStream.NAME:
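            # to_device rows are either for local users (entities starting
            # with '@') or for remote servers; only local users need waking
            # up here.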
            entities = [row.entity for row in rows if row.entity.startswith("@")]
            if entities:
                self.notifier.on_new_event("to_device_key", token, users=entities)
        elif stream_name == DeviceListsStream.NAME:
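            # Rows for local users (entities starting with '@') mean their
            # device lists have changed, so wake up the rooms they are in.
            # Other entities are remote servers, which are dealt with by the
            # federation sender.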
            all_room_ids: Set[str] = set()
            for row in rows:
                if row.entity.startswith("@"):
                    room_ids = await self.store.get_rooms_for_user(row.entity)
                    all_room_ids.update(room_ids)
            self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
        elif stream_name == GroupServerStream.NAME:
            self.notifier.on_new_event(
                "groups_key", token, users=[row.user_id for row in rows]
            )
        elif stream_name == PushersStream.NAME:
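            # Pusher rows tell us about pushers being added, updated or
            # deleted: stop deleted pushers and (re)start the rest.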
            for row in rows:
                if row.deleted:
                    self.stop_pusher(row.user_id, row.app_id, row.pushkey)
                else:
                    await self.start_pusher(row.user_id, row.app_id, row.pushkey)
        elif stream_name == EventsStream.NAME:
            # We shouldn't get multiple rows per token for the events stream, so
            # we don't need to optimise this for multiple rows.
            for row in rows:
                if row.type != EventsStreamEventRow.TypeId:
                    continue
                assert isinstance(row, EventsStreamRow)
                assert isinstance(row.data, EventsStreamEventRow)

                if row.data.rejected:
                    continue

                extra_users: Tuple[UserID, ...] = ()
                if row.data.type == EventTypes.Member and row.data.state_key:
                    extra_users = (UserID.from_string(row.data.state_key),)

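                # Notify with both this event's position and the current
                # maximum stream token; the two can differ when other event
                # persisters have written more recent events.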
                max_token = self.store.get_room_max_token()
                event_pos = PersistedEventPosition(instance_name, token)
                await self.notifier.on_new_room_event_args(
                    event_pos=event_pos,
                    max_room_stream_token=max_token,
                    extra_users=extra_users,
                    room_id=row.data.room_id,
                    event_id=row.data.event_id,
                    event_type=row.data.type,
                    state_key=row.data.state_key,
                    membership=row.data.membership,
                )

        await self._presence_handler.process_replication_rows(
            stream_name, instance_name, token, rows
        )

        # Notify any waiting deferreds. The list is ordered by position so we
        # just iterate through the list until we reach a position that is
        # greater than the received row position.
        waiting_list = self._streams_to_waiters.get(stream_name, [])

        # Index of the first item with a position after the current token, i.e.
        # we have called all deferreds before this index. If the loop below
        # does not overwrite it, then either a) the list is empty, so this is a
        # no-op, or b) every deferred in the list was called, so the whole list
        # should be cleared. Setting it to `len(waiting_list)` works for both
        # cases.
        index_of_first_deferred_not_called = len(waiting_list)

        for idx, (position, deferred) in enumerate(waiting_list):
            if position <= token:
                try:
                    with PreserveLoggingContext():
                        deferred.callback(None)
                except Exception:
                    # The deferred has been cancelled or timed out.
                    pass
            else:
                # The list is sorted by position so we don't need to continue
                # checking any further entries in the list.
                index_of_first_deferred_not_called = idx
                break

        # Drop all entries in the waiting list that were called in the above
        # loop. (This maintains the order so no need to resort)
        waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]

    async def on_position(self, stream_name: str, instance_name: str, token: int):
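        # Treat the new position as an empty batch of rows: this advances our
        # tracked position and fires any deferreds waiting in
        # `wait_for_stream_position`.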
        await self.on_rdata(stream_name, instance_name, token, [])

        # We poke the generic "replication" notifier to wake anything up that
        # may be streaming.
        self.notifier.notify_replication()

    def on_remote_server_up(self, server: str):
        """Called when we get a new REMOTE_SERVER_UP command."""

        # Let's wake up the transaction queue for the server in case we have
        # pending stuff to send to it.
        if self.send_handler:
            self.send_handler.wake_destination(server)

    async def wait_for_stream_position(
        self, instance_name: str, stream_name: str, position: int
    ):
        """Wait until this instance has received updates up to and including
        the given stream position.
        """

        if instance_name == self._instance_name:
            # We don't get told about updates written by this process, and
            # anyway in that case we don't need to wait.
            return

        current_position = self._streams[stream_name].current_token(self._instance_name)
        if position <= current_position:
            # We're already past the position
            return

        # Create a new deferred that times out after N seconds, as we don't want
        # to wedge here forever.
        deferred: "Deferred[None]" = Deferred()
        deferred = timeout_deferred(
            deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
        )

        waiting_list = self._streams_to_waiters.setdefault(stream_name, [])

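        # Insert the new waiter and re-sort: `on_rdata` relies on this list
        # being ordered by stream position.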
        waiting_list.append((position, deferred))
        waiting_list.sort(key=lambda t: t[0])

        # We measure here to get in flight counts and average waiting time.
        with Measure(self._clock, "repl.wait_for_stream_position"):
            logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
            await make_deferred_yieldable(deferred)
            logger.info(
                "Finished waiting for repl stream %r to reach %s", stream_name, position
            )

    def stop_pusher(self, user_id, app_id, pushkey):
        if not self._notify_pushers:
            return

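        # Pushers are keyed by "app_id:pushkey" within each user's map.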
        key = "%s:%s" % (app_id, pushkey)
        pushers_for_user = self._pusher_pool.pushers.get(user_id, {})
        pusher = pushers_for_user.pop(key, None)
        if pusher is None:
            return
        logger.info("Stopping pusher %r / %r", user_id, key)
        pusher.on_stop()

    async def start_pusher(self, user_id, app_id, pushkey):
        if not self._notify_pushers:
            return

        key = "%s:%s" % (app_id, pushkey)
        logger.info("Starting pusher %r / %r", user_id, key)
        return await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)


class FederationSenderHandler:
    """Processes the federation replication stream

    This class is only instantiated on the worker responsible for sending outbound
    federation transactions. It receives rows from the replication stream and forwards
    the appropriate entries to the FederationSender class.
    """

    def __init__(self, hs: "HomeServer"):
        assert hs.should_send_federation()

        self.store = hs.get_datastore()
        self._is_mine_id = hs.is_mine_id
        self._hs = hs

        # We need to make a temporary value to ensure that mypy picks up the
        # right type. We know we should have a federation sender instance since
        # `should_send_federation` is True.
        sender = hs.get_federation_sender()
        assert isinstance(sender, FederationSender)
        self.federation_sender = sender

        # Stores the latest position in the federation stream we've gotten up
        # to. This is always set before we use it.
        self.federation_position: Optional[int] = None

        self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")

    def wake_destination(self, server: str):
        self.federation_sender.wake_destination(server)

    async def process_replication_rows(self, stream_name, token, rows):
        # The federation stream contains things that we want to send out, e.g.
        # presence, typing, etc.
        if stream_name == "federation":
            send_queue.process_rows_for_federation(self.federation_sender, rows)
            await self.update_token(token)

        # ... and when new receipts happen
        elif stream_name == ReceiptsStream.NAME:
            await self._on_new_receipts(rows)

        # ... as well as device updates and messages
        elif stream_name == DeviceListsStream.NAME:
            # The entities are either user IDs (starting with '@') whose devices
            # have changed, or remote servers that we need to tell about
            # changes.
            hosts = {row.entity for row in rows if not row.entity.startswith("@")}
            for host in hosts:
                self.federation_sender.send_device_messages(host)

        elif stream_name == ToDeviceStream.NAME:
            # The to_device stream includes stuff to be pushed to both local
            # clients and remote servers, so we ignore entities that start with
            # '@' (since they'll be local users rather than destinations).
            hosts = {row.entity for row in rows if not row.entity.startswith("@")}
            for host in hosts:
                self.federation_sender.send_device_messages(host)

    async def _on_new_receipts(self, rows):
        """
        Args:
            rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
                new receipts to be processed
        """
        for receipt in rows:
            # we only want to send on receipts for our own users
            if not self._is_mine_id(receipt.user_id):
                continue
            if (
                receipt.data.get("hidden", False)
                and self._hs.config.experimental.msc2285_enabled
            ):
                continue
            receipt_info = ReadReceipt(
                receipt.room_id,
                receipt.receipt_type,
                receipt.user_id,
                [receipt.event_id],
                receipt.data,
            )
            await self.federation_sender.send_read_receipt(receipt_info)

    async def update_token(self, token):
        """Update the record of where we have processed to in the federation stream.

        Called after we have processed an update received over replication. Sends
        a FEDERATION_ACK back to the master, and stores the token that we have
        processed in `federation_stream_position` so that we can restart where we
        left off.
        """
        self.federation_position = token

        # We save and send the ACK to master asynchronously, so we don't block
        # processing on persistence. We don't need to do this operation for
        # every single RDATA we receive; we just need to do it periodically.

        if self._fed_position_linearizer.is_queued(None):
            # There is already a task queued up to save and send the token, so
            # no need to queue up another task.
            return

        run_as_background_process("_save_and_send_ack", self._save_and_send_ack)

    async def _save_and_send_ack(self):
        """Save the current federation position in the database and send an ACK
        to master with where we're up to.
        """
        # We should only be calling this once we've got a token.
        assert self.federation_position is not None

        try:
            # We linearize here to ensure we don't have races updating the token
            #
            # XXX this appears to be redundant, since the ReplicationCommandHandler
            # has a linearizer which ensures that we only process one line of
            # replication data at a time. Should we remove it, or is it doing useful
            # service for robustness? Or could we replace it with an assertion that
            # we're not being re-entered?

            with (await self._fed_position_linearizer.queue(None)):
                # We persist and ack the same position, so we take a copy of it
                # here as otherwise it can get modified from underneath us.
                current_position = self.federation_position

                await self.store.update_federation_out_pos(
                    "federation", current_position
                )
                # We ACK this token over replication so that the master can drop
                # its in-memory queues
                self._hs.get_tcp_replication().send_federation_ack(current_position)
        except Exception:
            logger.exception("Error updating federation stream position")