1# Copyright 2017 Vector Creations Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""The server side of the replication stream.
15"""
16
17import logging
18import random
19from typing import TYPE_CHECKING
20
21from prometheus_client import Counter
22
23from twisted.internet.protocol import ServerFactory
24
25from synapse.metrics.background_process_metrics import run_as_background_process
26from synapse.replication.tcp.commands import PositionCommand
27from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
28from synapse.replication.tcp.streams import EventsStream
29from synapse.util.metrics import Measure
30
31if TYPE_CHECKING:
32    from synapse.server import HomeServer
33
# Counts RDATA rows pushed out to connected clients, labelled per stream, so
# operators can see which streams generate the replication traffic.
stream_updates_counter = Counter(
    "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
)

logger = logging.getLogger(__name__)
39
40
class ReplicationStreamProtocolFactory(ServerFactory):
    """Twisted factory that builds a server-side replication protocol for
    every inbound replication connection.
    """

    def __init__(self, hs: "HomeServer"):
        self.server_name = hs.config.server.server_name
        self.clock = hs.get_clock()
        self.command_handler = hs.get_tcp_replication()

        # Creating this factory almost always means a replication listener is
        # being registered, so make sure the `ReplicationStreamer` singleton
        # has been started so there is something actually pushing data out.
        #
        # (This is an odd place to trigger that, but the alternatives — e.g.
        # doing it in `HomeServer.setup()` — would require re-plumbing the
        # listener config or unconditionally starting a `ReplicationStreamer`.)
        hs.get_replication_streamer()

    def buildProtocol(self, addr):
        # Called by Twisted for each new inbound connection.
        return ServerReplicationStreamProtocol(
            self.server_name, self.clock, self.command_handler
        )
63
64
class ReplicationStreamer:
    """Handles replication connections.

    This needs to be poked when new replication data may be available. When new
    data is available it will propagate to all connected clients.

    Concurrency model: `on_notifier_poke` may be invoked many times while the
    background loop is running; `pending_updates`/`is_looping` coordinate so
    that at most one `_run_notifier_loop` runs at a time and no poke is lost.
    """

    def __init__(self, hs: "HomeServer"):
        self.store = hs.get_datastore()
        self.clock = hs.get_clock()
        self.notifier = hs.get_notifier()
        self._instance_name = hs.get_instance_name()

        # Optional artificial delay (ms) inserted between streams while
        # looping; used to shake out ordering assumptions in tests.
        self._replication_torture_level = hs.config.server.replication_torture_level

        # Get poked whenever the notifier thinks there may be new data.
        self.notifier.add_replication_callback(self.on_notifier_poke)

        # Keeps track of whether we are currently checking for updates
        self.is_looping = False
        self.pending_updates = False

        self.command_handler = hs.get_tcp_replication()

        # Set of streams to replicate.
        self.streams = self.command_handler.get_streams_to_replicate()

        # If we have streams then we must have redis enabled or on master
        assert (
            not self.streams
            or hs.config.redis.redis_enabled
            or not hs.config.worker.worker_app
        )

        # If we are replicating an event stream we want to periodically check if
        # we should send updated POSITIONs. We do this as a looping call rather
        # explicitly poking when the position advances (without new data to
        # replicate) to reduce replication traffic (otherwise each writer would
        # likely send a POSITION for each new event received over replication).
        #
        # Note that if the position hasn't advanced then we won't send anything.
        if any(EventsStream.NAME == s.NAME for s in self.streams):
            self.clock.looping_call(self.on_notifier_poke, 1000)

    def on_notifier_poke(self) -> None:
        """Checks if there is actually any new data and sends it to the
        connections if there are.

        This should get called each time new data is available, even if it
        is currently being executed, so that nothing gets missed
        """
        if not self.command_handler.connected() or not self.streams:
            # Don't bother if nothing is listening. We still need to advance
            # the stream tokens otherwise they'll fall behind forever
            for stream in self.streams:
                stream.discard_updates_and_advance()
            return

        # We check up front to see if anything has actually changed, as we get
        # poked because of changes that happened on other instances.
        if all(
            stream.last_token == stream.current_token(self._instance_name)
            for stream in self.streams
        ):
            return

        # If there are updates then we need to set this even if we're already
        # looping, as the loop needs to know that it might need to loop again.
        self.pending_updates = True

        if self.is_looping:
            logger.debug("Notifier poke loop already running")
            return

        # Kick off the loop in the background; `is_looping` is set inside the
        # loop itself, so a poke racing with this call at worst starts a second
        # loop that immediately sees `pending_updates` already cleared.
        run_as_background_process("replication_notifier", self._run_notifier_loop)

    async def _run_notifier_loop(self) -> None:
        # Drain updates from all streams and push them to clients, looping
        # until no new pokes arrived while we were working.
        self.is_looping = True

        try:
            # Keep looping while there have been pokes about potential updates.
            # This protects against the race where a stream we already checked
            # gets an update while we're handling other streams.
            while self.pending_updates:
                self.pending_updates = False

                with Measure(self.clock, "repl.stream.get_updates"):
                    all_streams = self.streams

                    if self._replication_torture_level is not None:
                        # there is no guarantee about ordering between the streams,
                        # so let's shuffle them around a bit when we are in torture mode.
                        all_streams = list(all_streams)
                        random.shuffle(all_streams)

                    for stream in all_streams:
                        # Skip streams whose position hasn't moved.
                        if stream.last_token == stream.current_token(
                            self._instance_name
                        ):
                            continue

                        if self._replication_torture_level:
                            await self.clock.sleep(
                                self._replication_torture_level / 1000.0
                            )

                        # Remember where we started so a POSITION (below) can
                        # describe the range it covers.
                        last_token = stream.last_token

                        logger.debug(
                            "Getting stream: %s: %s -> %s",
                            stream.NAME,
                            stream.last_token,
                            stream.current_token(self._instance_name),
                        )
                        try:
                            updates, current_token, limited = await stream.get_updates()
                            # `limited` means the stream returned a capped batch;
                            # force another pass of the outer loop to fetch more.
                            self.pending_updates |= limited
                        except Exception:
                            logger.info("Failed to handle stream %s", stream.NAME)
                            raise

                        logger.debug(
                            "Sending %d updates",
                            len(updates),
                        )

                        if updates:
                            logger.info(
                                "Streaming: %s -> %s", stream.NAME, updates[-1][0]
                            )
                            stream_updates_counter.labels(stream.NAME).inc(len(updates))

                        else:
                            # The token has advanced but there is no data to
                            # send, so we send a `POSITION` to inform other
                            # workers of the updated position.
                            if stream.NAME == EventsStream.NAME:
                                # XXX: We only do this for the EventStream as it
                                # turns out that e.g. account data streams share
                                # their "current token" with each other, meaning
                                # that it is *not* safe to send a POSITION.
                                logger.info(
                                    "Sending position: %s -> %s",
                                    stream.NAME,
                                    current_token,
                                )
                                self.command_handler.send_command(
                                    PositionCommand(
                                        stream.NAME,
                                        self._instance_name,
                                        last_token,
                                        current_token,
                                    )
                                )
                            continue

                        # Some streams return multiple rows with the same stream IDs,
                        # we need to make sure they get sent out in batches. We do
                        # this by setting the current token to all but the last of
                        # a series of updates with the same token to have a None
                        # token. See RdataCommand for more details.
                        batched_updates = _batch_updates(updates)

                        for token, row in batched_updates:
                            try:
                                self.command_handler.stream_update(
                                    stream.NAME, token, row
                                )
                            except Exception:
                                # Best-effort: a failure to push one row must not
                                # abort the loop for the remaining rows/streams.
                                logger.exception("Failed to replicate")

            logger.debug("No more pending updates, breaking poke loop")
        finally:
            # Always reset the flags so a future poke can start a fresh loop,
            # even if we bailed out with an exception above.
            self.pending_updates = False
            self.is_looping = False
239
240
241def _batch_updates(updates):
242    """Takes a list of updates of form [(token, row)] and sets the token to
243    None for all rows where the next row has the same token. This is used to
244    implement batching.
245
246    For example:
247
248        [(1, _), (1, _), (2, _), (3, _), (3, _)]
249
250    becomes:
251
252        [(None, _), (1, _), (2, _), (None, _), (3, _)]
253    """
254    if not updates:
255        return []
256
257    new_updates = []
258    for i, update in enumerate(updates[:-1]):
259        if update[0] == updates[i + 1][0]:
260            new_updates.append((None, update[1]))
261        else:
262            new_updates.append(update)
263
264    new_updates.append(updates[-1])
265    return new_updates
266