# Copyright 2017 Vector Creations Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
)

from prometheus_client import Counter
from typing_extensions import Deque

from twisted.internet.protocol import ReconnectingClientFactory

from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
from synapse.replication.tcp.commands import (
    ClearUserSyncsCommand,
    Command,
    FederationAckCommand,
    PositionCommand,
    RdataCommand,
    RemoteServerUpCommand,
    ReplicateCommand,
    UserIpCommand,
    UserSyncCommand,
)
from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams import (
    STREAMS_MAP,
    AccountDataStream,
    BackfillStream,
    CachesStream,
    EventsStream,
    FederationStream,
    PresenceFederationStream,
    PresenceStream,
    ReceiptsStream,
    Stream,
    TagAccountDataStream,
    ToDeviceStream,
    TypingStream,
)

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


# number of updates received for each RDATA stream
inbound_rdata_count = Counter(
    "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
)
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")

user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")


# the type of the entries in _command_queues_by_stream
_StreamCommandQueue = Deque[
    Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
]
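# Each entry pairs an inbound command with the connection it arrived on, e.g.
# (RdataCommand(...), conn). RDATA and POSITION commands for a stream share a
# single queue so that they are processed strictly in arrival order.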


class ReplicationCommandHandler:
    """Handles incoming commands from replication as well as sending commands
    back out to connections.
    """

    def __init__(self, hs: "HomeServer"):
        self._replication_data_handler = hs.get_replication_data_handler()
        self._presence_handler = hs.get_presence_handler()
        self._store = hs.get_datastore()
        self._notifier = hs.get_notifier()
        self._clock = hs.get_clock()
        self._instance_id = hs.get_instance_id()
        self._instance_name = hs.get_instance_name()

        self._is_presence_writer = (
            hs.get_instance_name() in hs.config.worker.writers.presence
        )

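        # All the streams this instance knows about, keyed by name, regardless
        # of whether this instance writes to them.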
        self._streams: Dict[str, Stream] = {
            stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
        }

        # List of streams that this instance is the source of
        self._streams_to_replicate: List[Stream] = []

        for stream in self._streams.values():
            if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME:
                # All workers can write to the cache invalidation stream when
                # using redis.
                self._streams_to_replicate.append(stream)
                continue

            if isinstance(stream, (EventsStream, BackfillStream)):
                # Only add EventsStream and BackfillStream as a source on the
                # instance in charge of event persistence.
                if hs.get_instance_name() in hs.config.worker.writers.events:
                    self._streams_to_replicate.append(stream)

                continue

            if isinstance(stream, ToDeviceStream):
                # Only add ToDeviceStream as a source on instances in charge of
                # sending to device messages.
                if hs.get_instance_name() in hs.config.worker.writers.to_device:
                    self._streams_to_replicate.append(stream)

                continue

            if isinstance(stream, TypingStream):
                # Only add TypingStream as a source on the instance in charge of
                # typing.
                if hs.get_instance_name() in hs.config.worker.writers.typing:
                    self._streams_to_replicate.append(stream)

                continue

            if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
                # Only add AccountDataStream and TagAccountDataStream as a source on the
                # instance in charge of account_data persistence.
                if hs.get_instance_name() in hs.config.worker.writers.account_data:
                    self._streams_to_replicate.append(stream)

                continue

            if isinstance(stream, ReceiptsStream):
                # Only add ReceiptsStream as a source on the instance in charge of
                # receipts.
                if hs.get_instance_name() in hs.config.worker.writers.receipts:
                    self._streams_to_replicate.append(stream)

                continue

            if isinstance(stream, (PresenceStream, PresenceFederationStream)):
                # Only add PresenceStream as a source on the instance in charge
                # of presence.
                if self._is_presence_writer:
                    self._streams_to_replicate.append(stream)

                continue

            # Only add any other streams if we're on master.
            if hs.config.worker.worker_app is not None:
                continue

            if (
                stream.NAME == FederationStream.NAME
                and hs.config.worker.send_federation
            ):
                # We only support federation stream if federation sending
                # has been disabled on the master.
                continue

            self._streams_to_replicate.append(stream)

        # Map of stream name to batched updates. See RdataCommand for info on
        # how batching works.
        self._pending_batches: Dict[str, List[Any]] = {}

        # The factory used to create connections.
        self._factory: Optional[ReconnectingClientFactory] = None

        # The currently connected connections. (The list of places we need to send
        # outgoing replication commands to.)
        self._connections: List[IReplicationConnection] = []

        LaterGauge(
            "synapse_replication_tcp_resource_total_connections",
            "",
            [],
            lambda: len(self._connections),
        )

        # When POSITION or RDATA commands arrive, we stick them in a queue and process
        # them in order in a separate background process.

        # the streams which are currently being processed by _unsafe_process_queue
        self._processing_streams: Set[str] = set()

        # for each stream, a queue of commands that are awaiting processing, and the
        # connection that they arrived on.
        self._command_queues_by_stream = {
            stream_name: _StreamCommandQueue() for stream_name in self._streams
        }

        # For each connection, the incoming stream names that have received a POSITION
        # from that connection.
        self._streams_by_connection: Dict[IReplicationConnection, Set[str]] = {}

        LaterGauge(
            "synapse_replication_tcp_command_queue",
            "Number of inbound RDATA/POSITION commands queued for processing",
            ["stream_name"],
            lambda: {
                (stream_name,): len(queue)
                for stream_name, queue in self._command_queues_by_stream.items()
            },
        )

        self._is_master = hs.config.worker.worker_app is None

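        # The master keeps a federation sender (and handles FEDERATION_ACK)
        # only when federation sending hasn't been moved to a worker.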
        self._federation_sender = None
        if self._is_master and not hs.config.worker.send_federation:
            self._federation_sender = hs.get_federation_sender()

        self._server_notices_sender = None
        if self._is_master:
            self._server_notices_sender = hs.get_server_notices_sender()

    def _add_command_to_stream_queue(
        self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
    ) -> None:
        """Queue the given received command for processing

        Adds the given command to the per-stream queue, and processes the queue if
        necessary
        """
        stream_name = cmd.stream_name
        queue = self._command_queues_by_stream.get(stream_name)
        if queue is None:
            logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
            return

        queue.append((cmd, conn))

        # if we're already processing this stream, there's nothing more to do:
        # the new entry on the queue will get picked up in due course
        if stream_name in self._processing_streams:
            return

        # fire off a background process to start processing the queue.
        run_as_background_process(
            "process-replication-data", self._unsafe_process_queue, stream_name
        )

    async def _unsafe_process_queue(self, stream_name: str):
        """Processes the command queue for the given stream, until it is empty

        Does not check if there is already a thread processing the queue, hence "unsafe"
        """
        assert stream_name not in self._processing_streams

        self._processing_streams.add(stream_name)
        try:
            queue = self._command_queues_by_stream.get(stream_name)
            while queue:
                cmd, conn = queue.popleft()
                try:
                    await self._process_command(cmd, conn, stream_name)
                except Exception:
                    logger.exception("Failed to handle command %s", cmd)
        finally:
            self._processing_streams.discard(stream_name)

    async def _process_command(
        self,
        cmd: Union[PositionCommand, RdataCommand],
        conn: IReplicationConnection,
        stream_name: str,
    ) -> None:
        if isinstance(cmd, PositionCommand):
            await self._process_position(stream_name, conn, cmd)
        elif isinstance(cmd, RdataCommand):
            await self._process_rdata(stream_name, conn, cmd)
        else:
            # This shouldn't be possible
            raise Exception("Unrecognised command %s in stream queue" % (cmd.NAME,))

    def start_replication(self, hs: "HomeServer"):
        """Helper method to start a replication connection to the remote server
        using TCP.
        """
        if hs.config.redis.redis_enabled:
            from synapse.replication.tcp.redis import (
                RedisDirectTcpReplicationClientFactory,
            )

            # First let's ensure that we have a ReplicationStreamer started.
            hs.get_replication_streamer()

            # We need two connections to redis, one for the subscription stream and
            # one to send commands to (as you can't send further redis commands to a
            # connection after SUBSCRIBE is called).

            # First create the connection for sending commands.
            outbound_redis_connection = hs.get_outbound_redis_connection()

            # Now create the factory/connection for the subscription stream.
            self._factory = RedisDirectTcpReplicationClientFactory(
                hs, outbound_redis_connection
            )
            hs.get_reactor().connectTCP(
                hs.config.redis.redis_host,  # type: ignore[arg-type]
                hs.config.redis.redis_port,
                self._factory,
                timeout=30,
                bindAddress=None,
            )
        else:
            client_name = hs.get_instance_name()
            self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
            host = hs.config.worker.worker_replication_host
            port = hs.config.worker.worker_replication_port
            hs.get_reactor().connectTCP(
                host,  # type: ignore[arg-type]
                port,
                self._factory,
                timeout=30,
                bindAddress=None,
            )

    def get_streams(self) -> Dict[str, Stream]:
        """Get a map from stream name to all streams."""
        return self._streams

    def get_streams_to_replicate(self) -> List[Stream]:
        """Get a list of streams that this instance replicates."""
        return self._streams_to_replicate

    def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
        """Called when a connection asks us to replicate: respond with the
        current positions of the streams we are a source of."""
        self.send_positions_to_connection(conn)

    def send_positions_to_connection(self, conn: IReplicationConnection):
        """Send current position of all streams this process is source of to
        the connection.
        """

        # We respond with current position of all streams this instance
        # replicates.
        for stream in self.get_streams_to_replicate():
            # Note that we use the current token as the prev token here (rather
            # than stream.last_token), as we can't be sure that there have been
            # no rows written between last token and the current token (since we
            # might be racing with the replication sending bg process).
            current_token = stream.current_token(self._instance_name)
            self.send_command(
                PositionCommand(
                    stream.NAME,
                    self._instance_name,
                    current_token,
                    current_token,
                )
            )

    def on_USER_SYNC(
        self, conn: IReplicationConnection, cmd: UserSyncCommand
    ) -> Optional[Awaitable[None]]:
        user_sync_counter.inc()

        if self._is_presence_writer:
            return self._presence_handler.update_external_syncs_row(
                cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
            )
        else:
            return None

    def on_CLEAR_USER_SYNC(
        self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
    ) -> Optional[Awaitable[None]]:
        if self._is_presence_writer:
            return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
        else:
            return None

    def on_FEDERATION_ACK(
        self, conn: IReplicationConnection, cmd: FederationAckCommand
    ):
        federation_ack_counter.inc()

        if self._federation_sender:
            self._federation_sender.federation_ack(cmd.instance_name, cmd.token)

    def on_USER_IP(
        self, conn: IReplicationConnection, cmd: UserIpCommand
    ) -> Optional[Awaitable[None]]:
        user_ip_cache_counter.inc()

        if self._is_master:
            return self._handle_user_ip(cmd)
        else:
            return None

    async def _handle_user_ip(self, cmd: UserIpCommand):
        await self._store.insert_client_ip(
            cmd.user_id,
            cmd.access_token,
            cmd.ip,
            cmd.user_agent,
            cmd.device_id,
            cmd.last_seen,
        )

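        # The server notices sender is always set on master (see __init__), and
        # on_USER_IP only calls us when we are the master instance.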
        assert self._server_notices_sender is not None
        await self._server_notices_sender.on_user_ip(cmd.user_id)

    def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
        if cmd.instance_name == self._instance_name:
            # Ignore RDATA that are just our own echoes
            return

        stream_name = cmd.stream_name
        inbound_rdata_count.labels(stream_name).inc()

        # We put the received command into a queue here for two reasons:
        #   1. so we don't try and concurrently handle multiple rows for the
        #      same stream, and
        #   2. so we don't race with getting a POSITION command and fetching
        #      missing RDATA.

        self._add_command_to_stream_queue(conn, cmd)

    async def _process_rdata(
        self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
    ) -> None:
        """Process an RDATA command

        Called after the command has been popped off the queue of inbound commands
        """
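        # parse_row turns the raw row received over the wire into the
        # stream's ROW_TYPE.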
        try:
            row = STREAMS_MAP[stream_name].parse_row(cmd.row)
        except Exception as e:
            raise Exception(
                "Failed to parse RDATA: %r %r" % (stream_name, cmd.row)
            ) from e

        # make sure that we've processed a POSITION for this stream *on this
        # connection*. (A POSITION on another connection is no good, as there
        # is no guarantee that we have seen all the intermediate updates.)
        sbc = self._streams_by_connection.get(conn)
        if not sbc or stream_name not in sbc:
            # Let's drop the row for now, on the assumption we'll receive a
            # `POSITION` soon and we'll catch up correctly then.
            logger.debug(
                "Discarding RDATA for unconnected stream %s -> %s",
                stream_name,
                cmd.token,
            )
            return

        if cmd.token is None:
            # I.e. this is part of a batch of updates for this stream (in
            # which case batch until we get an update for the stream with a non
            # None token).
            self._pending_batches.setdefault(stream_name, []).append(row)
            return

        # Check if this is the last of a batch of updates
        rows = self._pending_batches.pop(stream_name, [])
        rows.append(row)

        stream = self._streams[stream_name]

        # Find where we previously streamed up to.
        current_token = stream.current_token(cmd.instance_name)

        # Discard this data if this token is earlier than the current
        # position. Note that streams can be reset (in which case you
        # expect an earlier token), but that must be preceded by a
        # POSITION command.
        if cmd.token <= current_token:
            logger.debug(
                "Discarding RDATA from stream %s at position %s before previous position %s",
                stream_name,
                cmd.token,
                current_token,
            )
        else:
            await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)

    async def on_rdata(
        self, stream_name: str, instance_name: str, token: int, rows: list
    ):
        """Called to handle a batch of replication data with a given stream token.

        Args:
            stream_name: name of the replication stream for this batch of rows
            instance_name: the instance that wrote the rows.
            token: stream token for this batch of rows
            rows: a list of Stream.ROW_TYPE objects as returned by
                Stream.parse_row.
        """
        logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
        await self._replication_data_handler.on_rdata(
            stream_name, instance_name, token, rows
        )

    def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
        if cmd.instance_name == self._instance_name:
            # Ignore POSITION that are just our own echoes
            return

        logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())

        self._add_command_to_stream_queue(conn, cmd)

    async def _process_position(
        self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
    ) -> None:
        """Process a POSITION command

        Called after the command has been popped off the queue of inbound commands
        """
        stream = self._streams[stream_name]

        # We're about to go and catch up with the stream, so remove from set
        # of connected streams.
        for streams in self._streams_by_connection.values():
            streams.discard(stream_name)

        # We clear the pending batches for the stream as the fetching of the
        # missing updates below will fetch all rows in the batch.
        self._pending_batches.pop(stream_name, [])

        # Find where we previously streamed up to.
        current_token = stream.current_token(cmd.instance_name)

        # If the position token matches our current token then we're up to
        # date and there's nothing to do. Otherwise, fetch all updates
        # between then and now.
        missing_updates = cmd.prev_token != current_token
        while missing_updates:
            logger.info(
                "Fetching replication rows for '%s' between %i and %i",
                stream_name,
                current_token,
                cmd.new_token,
            )
            (updates, current_token, missing_updates) = await stream.get_updates_since(
                cmd.instance_name, current_token, cmd.new_token
            )
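            # get_updates_since may return only a partial batch of updates; in
            # that case missing_updates remains True and we loop to fetch more.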

            # TODO: add some tests for this

            # Some streams return multiple rows with the same stream IDs,
            # which need to be processed in batches.

            for token, rows in _batch_updates(updates):
                await self.on_rdata(
                    stream_name,
                    cmd.instance_name,
                    token,
                    [stream.parse_row(row) for row in rows],
                )

        logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token)

        # We've now caught up to position sent to us, notify handler.
        await self._replication_data_handler.on_position(
            cmd.stream_name, cmd.instance_name, cmd.new_token
        )

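        # Record that we have now seen a POSITION for this stream on this
        # connection, so that subsequent RDATA on it will be processed
        # (see _process_rdata).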
        self._streams_by_connection.setdefault(conn, set()).add(stream_name)

    def on_REMOTE_SERVER_UP(
        self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
    ):
        """Called when we get a new REMOTE_SERVER_UP command."""
        self._replication_data_handler.on_remote_server_up(cmd.data)

        self._notifier.notify_remote_server_up(cmd.data)

        # We relay to all other connections to ensure every instance gets the
        # notification.
        #
        # When configured to use redis we'll always only have one connection and
        # so this is a no-op (all instances will have already received the same
        # REMOTE_SERVER_UP command).
        #
        # For direct TCP connections this will relay to all other connections
        # connected to us. When on master this will correctly fan out to all
        # other direct TCP clients and on workers there'll only be the one
        # connection to master.
        #
        # (The logic here should also be sound if we have a mix of Redis and
        # direct TCP connections so long as there is only one traffic route
        # between two instances, but that is not currently supported).
        self.send_command(cmd, ignore_conn=conn)

    def new_connection(self, connection: IReplicationConnection):
        """Called when we have a new connection."""
        self._connections.append(connection)

        # If we are connected to replication as a client (rather than a server)
        # we need to reset the reconnection delay on the client factory (which
        # is used to do exponential back off when the connection drops).
        #
        # Ideally we would reset the delay when we've "fully established" the
        # connection (for some definition thereof) to stop us from tightlooping
        # on reconnection if something fails after this point and we drop the
        # connection. Unfortunately, we don't really have a better definition of
        # "fully established" than the connection being established.
        if self._factory:
            self._factory.resetDelay()

        # Tell the other end if we have any users currently syncing.
        currently_syncing = (
            self._presence_handler.get_currently_syncing_users_for_replication()
        )

        now = self._clock.time_msec()
        for user_id in currently_syncing:
            connection.send_command(
                UserSyncCommand(self._instance_id, user_id, True, now)
            )

    def lost_connection(self, connection: IReplicationConnection):
        """Called when a connection is closed/lost."""
        # we no longer need _streams_by_connection for this connection.
        streams = self._streams_by_connection.pop(connection, None)
        if streams:
            logger.info(
                "Lost replication connection; streams now disconnected: %s", streams
            )
        try:
            self._connections.remove(connection)
        except ValueError:
            pass

    def connected(self) -> bool:
        """Do we have any replication connections open?

        Is used by e.g. `ReplicationStreamer` to no-op if nothing is connected.
        """
        return bool(self._connections)

    def send_command(
        self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
    ):
        """Send a command to all connected connections.

        Args:
            cmd
            ignore_conn: If set don't send command to the given connection.
                Used when relaying commands from one connection to all others.
        """
        if self._connections:
            for connection in self._connections:
                if connection == ignore_conn:
                    continue

                try:
                    connection.send_command(cmd)
                except Exception:
                    # We probably want to catch some types of exceptions here
                    # and log them as warnings (e.g. connection gone), but it
                    # isn't clear exactly which exception types those would be.
                    logger.exception(
                        "Failed to write command %s to connection %s",
                        cmd.NAME,
                        connection,
                    )
        else:
            logger.warning("Dropping command as not connected: %r", cmd.NAME)

    def send_federation_ack(self, token: int):
        """Ack data for the federation stream. This allows the master to drop
        data stored purely in memory.
        """
        self.send_command(FederationAckCommand(self._instance_name, token))

    def send_user_sync(
        self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
    ):
        """Poke the master that a user has started/stopped syncing."""
        self.send_command(
            UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
        )

    def send_user_ip(
        self,
        user_id: str,
        access_token: str,
        ip: str,
        user_agent: str,
        device_id: str,
        last_seen: int,
    ):
        """Tell the master that the user made a request."""
        cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
        self.send_command(cmd)

    def send_remote_server_up(self, server: str):
        self.send_command(RemoteServerUpCommand(server))

    def stream_update(self, stream_name: str, token: Optional[int], data: Any):
        """Called when a new update is available to stream to clients.

        We need to check if the client is interested in the stream or not
        """
        self.send_command(RdataCommand(stream_name, self._instance_name, token, data))


UpdateToken = TypeVar("UpdateToken")
UpdateRow = TypeVar("UpdateRow")


def _batch_updates(
    updates: Iterable[Tuple[UpdateToken, UpdateRow]]
) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]:
    """Collect stream updates with the same token together

    Given a series of updates returned by Stream.get_updates_since(), collects
    the updates which share the same stream_id together.

    For example:

        [(1, a), (1, b), (2, c), (3, d), (3, e)]

    becomes:

        [
            (1, [a, b]),
            (2, [c]),
            (3, [d, e]),
        ]
    """

    update_iter = iter(updates)

    first_update = next(update_iter, None)
    if first_update is None:
        # empty input
        return

    current_batch_token = first_update[0]
    current_batch = [first_update[1]]

    for token, row in update_iter:
        if token != current_batch_token:
            # different token to the previous row: flush the previous
            # batch and start anew
            yield current_batch_token, current_batch
            current_batch_token = token
            current_batch = []

        current_batch.append(row)

    # flush the final batch
    yield current_batch_token, current_batch