# Copyright 2017 Vector Creations Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
)

from prometheus_client import Counter
from typing_extensions import Deque

from twisted.internet.protocol import ReconnectingClientFactory

from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
from synapse.replication.tcp.commands import (
    ClearUserSyncsCommand,
    Command,
    FederationAckCommand,
    PositionCommand,
    RdataCommand,
    RemoteServerUpCommand,
    ReplicateCommand,
    UserIpCommand,
    UserSyncCommand,
)
from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams import (
    STREAMS_MAP,
    AccountDataStream,
    BackfillStream,
    CachesStream,
    EventsStream,
    FederationStream,
    PresenceFederationStream,
    PresenceStream,
    ReceiptsStream,
    Stream,
    TagAccountDataStream,
    ToDeviceStream,
    TypingStream,
)

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


# number of updates received for each RDATA stream
inbound_rdata_count = Counter(
    "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
)
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")

user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")


# the type of the entries in _command_queues_by_stream
_StreamCommandQueue = Deque[
    Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
]


class ReplicationCommandHandler:
    """Handles incoming commands from replication as well as sending commands
    back out to connections.
    """

    def __init__(self, hs: "HomeServer"):
        self._replication_data_handler = hs.get_replication_data_handler()
        self._presence_handler = hs.get_presence_handler()
        self._store = hs.get_datastore()
        self._notifier = hs.get_notifier()
        self._clock = hs.get_clock()
        self._instance_id = hs.get_instance_id()
        self._instance_name = hs.get_instance_name()

        self._is_presence_writer = (
            hs.get_instance_name() in hs.config.worker.writers.presence
        )

        self._streams: Dict[str, Stream] = {
            stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
        }

        # List of streams that this instance is the source of
        self._streams_to_replicate: List[Stream] = []

        for stream in self._streams.values():
            if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME:
                # All workers can write to the cache invalidation stream when
                # using redis.
                self._streams_to_replicate.append(stream)
                continue

            if isinstance(stream, (EventsStream, BackfillStream)):
                # Only add EventStream and BackfillStream as a source on the
                # instance in charge of event persistence.
                if hs.get_instance_name() in hs.config.worker.writers.events:
                    self._streams_to_replicate.append(stream)

                continue

            if isinstance(stream, ToDeviceStream):
                # Only add ToDeviceStream as a source on instances in charge of
                # sending to device messages.
                if hs.get_instance_name() in hs.config.worker.writers.to_device:
                    self._streams_to_replicate.append(stream)

                continue

            if isinstance(stream, TypingStream):
                # Only add TypingStream as a source on the instance in charge of
                # typing.
                if hs.get_instance_name() in hs.config.worker.writers.typing:
                    self._streams_to_replicate.append(stream)

                continue

            if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
                # Only add AccountDataStream and TagAccountDataStream as a source on the
                # instance in charge of account_data persistence.
                if hs.get_instance_name() in hs.config.worker.writers.account_data:
                    self._streams_to_replicate.append(stream)

                continue

            if isinstance(stream, ReceiptsStream):
                # Only add ReceiptsStream as a source on the instance in charge of
                # receipts.
                if hs.get_instance_name() in hs.config.worker.writers.receipts:
                    self._streams_to_replicate.append(stream)

                continue

            if isinstance(stream, (PresenceStream, PresenceFederationStream)):
                # Only add PresenceStream as a source on the instance in charge
                # of presence.
                if self._is_presence_writer:
                    self._streams_to_replicate.append(stream)

                continue

            # Only add any other streams if we're on master.
            if hs.config.worker.worker_app is not None:
                continue

            if (
                stream.NAME == FederationStream.NAME
                and hs.config.worker.send_federation
            ):
                # We only support federation stream if federation sending
                # has been disabled on the master.
                continue

            self._streams_to_replicate.append(stream)

        # Map of stream name to batched updates. See RdataCommand for info on
        # how batching works.
        self._pending_batches: Dict[str, List[Any]] = {}

        # The factory used to create connections.
        self._factory: Optional[ReconnectingClientFactory] = None

        # The currently connected connections. (The list of places we need to send
        # outgoing replication commands to.)
        self._connections: List[IReplicationConnection] = []

        LaterGauge(
            "synapse_replication_tcp_resource_total_connections",
            "",
            [],
            lambda: len(self._connections),
        )

        # When POSITION or RDATA commands arrive, we stick them in a queue and process
        # them in order in a separate background process.

        # the streams which are currently being processed by _unsafe_process_queue
        self._processing_streams: Set[str] = set()

        # for each stream, a queue of commands that are awaiting processing, and the
        # connection that they arrived on.
        self._command_queues_by_stream = {
            stream_name: _StreamCommandQueue() for stream_name in self._streams
        }

        # For each connection, the incoming stream names that have received a POSITION
        # from that connection.
        self._streams_by_connection: Dict[IReplicationConnection, Set[str]] = {}

        LaterGauge(
            "synapse_replication_tcp_command_queue",
            "Number of inbound RDATA/POSITION commands queued for processing",
            ["stream_name"],
            lambda: {
                (stream_name,): len(queue)
                for stream_name, queue in self._command_queues_by_stream.items()
            },
        )

        self._is_master = hs.config.worker.worker_app is None

        self._federation_sender = None
        if self._is_master and not hs.config.worker.send_federation:
            self._federation_sender = hs.get_federation_sender()

        self._server_notices_sender = None
        if self._is_master:
            self._server_notices_sender = hs.get_server_notices_sender()

    def _add_command_to_stream_queue(
        self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
    ) -> None:
        """Queue the given received command for processing

        Adds the given command to the per-stream queue, and processes the queue if
        necessary
        """
        stream_name = cmd.stream_name
        queue = self._command_queues_by_stream.get(stream_name)
        if queue is None:
            logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
            return

        queue.append((cmd, conn))

        # if we're already processing this stream, there's nothing more to do:
        # the new entry on the queue will get picked up in due course
        if stream_name in self._processing_streams:
            return

        # fire off a background process to start processing the queue.
        run_as_background_process(
            "process-replication-data", self._unsafe_process_queue, stream_name
        )

    async def _unsafe_process_queue(self, stream_name: str) -> None:
        """Processes the command queue for the given stream, until it is empty

        Does not check if there is already a thread processing the queue, hence "unsafe"
        """
        assert stream_name not in self._processing_streams

        self._processing_streams.add(stream_name)
        try:
            queue = self._command_queues_by_stream.get(stream_name)
            while queue:
                cmd, conn = queue.popleft()
                try:
                    await self._process_command(cmd, conn, stream_name)
                except Exception:
                    # A failure on one command must not stop the rest of the
                    # queue from being drained.
                    logger.exception("Failed to handle command %s", cmd)
        finally:
            self._processing_streams.discard(stream_name)

    async def _process_command(
        self,
        cmd: Union[PositionCommand, RdataCommand],
        conn: IReplicationConnection,
        stream_name: str,
    ) -> None:
        """Dispatch a queued command to the POSITION or RDATA processor.

        Args:
            cmd: the command popped off the stream queue
            conn: the connection the command arrived on
            stream_name: name of the stream whose queue is being processed

        Raises:
            Exception: if the command is neither a POSITION nor an RDATA.
        """
        if isinstance(cmd, PositionCommand):
            await self._process_position(stream_name, conn, cmd)
        elif isinstance(cmd, RdataCommand):
            await self._process_rdata(stream_name, conn, cmd)
        else:
            # This shouldn't be possible.
            #
            # Note that Exception does not %-format its arguments, so the
            # message has to be interpolated explicitly.
            raise Exception(
                "Unrecognised command %s in stream queue" % (cmd.NAME,)
            )

    def start_replication(self, hs: "HomeServer") -> None:
        """Helper method to start a replication connection to the remote server
        using TCP.
        """
        if hs.config.redis.redis_enabled:
            from synapse.replication.tcp.redis import (
                RedisDirectTcpReplicationClientFactory,
            )

            # First let's ensure that we have a ReplicationStreamer started.
            hs.get_replication_streamer()

            # We need two connections to redis, one for the subscription stream and
            # one to send commands to (as you can't send further redis commands to a
            # connection after SUBSCRIBE is called).

            # First create the connection for sending commands.
            outbound_redis_connection = hs.get_outbound_redis_connection()

            # Now create the factory/connection for the subscription stream.
            self._factory = RedisDirectTcpReplicationClientFactory(
                hs, outbound_redis_connection
            )
            hs.get_reactor().connectTCP(
                hs.config.redis.redis_host,  # type: ignore[arg-type]
                hs.config.redis.redis_port,
                self._factory,
                timeout=30,
                bindAddress=None,
            )
        else:
            client_name = hs.get_instance_name()
            self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
            host = hs.config.worker.worker_replication_host
            port = hs.config.worker.worker_replication_port
            hs.get_reactor().connectTCP(
                host,  # type: ignore[arg-type]
                port,
                self._factory,
                timeout=30,
                bindAddress=None,
            )

    def get_streams(self) -> Dict[str, Stream]:
        """Get a map from stream name to all streams."""
        return self._streams

    def get_streams_to_replicate(self) -> List[Stream]:
        """Get a list of streams that this instance replicates."""
        return self._streams_to_replicate

    def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand) -> None:
        """Called when we get a REPLICATE command: respond with our stream
        positions.
        """
        self.send_positions_to_connection(conn)

    def send_positions_to_connection(self, conn: IReplicationConnection) -> None:
        """Send current position of all streams this process is source of to
        the connection.

        Note that the POSITION commands are sent via `send_command`, i.e. they
        go to every current connection rather than only to `conn`.
        """

        # We respond with current position of all streams this instance
        # replicates.
        for stream in self.get_streams_to_replicate():
            # Note that we use the current token as the prev token here (rather
            # than stream.last_token), as we can't be sure that there have been
            # no rows written between last token and the current token (since we
            # might be racing with the replication sending bg process).
            current_token = stream.current_token(self._instance_name)
            self.send_command(
                PositionCommand(
                    stream.NAME,
                    self._instance_name,
                    current_token,
                    current_token,
                )
            )

    def on_USER_SYNC(
        self, conn: IReplicationConnection, cmd: UserSyncCommand
    ) -> Optional[Awaitable[None]]:
        """Called when we get a USER_SYNC command.

        Returns:
            The awaitable from the presence handler if this instance is a
            presence writer, otherwise None.
        """
        user_sync_counter.inc()

        if self._is_presence_writer:
            return self._presence_handler.update_external_syncs_row(
                cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
            )
        else:
            return None

    def on_CLEAR_USER_SYNC(
        self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
    ) -> Optional[Awaitable[None]]:
        """Called when we get a CLEAR_USER_SYNC command.

        Returns:
            The awaitable from the presence handler if this instance is a
            presence writer, otherwise None.
        """
        if self._is_presence_writer:
            return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
        else:
            return None

    def on_FEDERATION_ACK(
        self, conn: IReplicationConnection, cmd: FederationAckCommand
    ) -> None:
        """Called when we get a FEDERATION_ACK command.

        The ack is passed on to the federation sender, if this instance has one.
        """
        federation_ack_counter.inc()

        if self._federation_sender:
            self._federation_sender.federation_ack(cmd.instance_name, cmd.token)

    def on_USER_IP(
        self, conn: IReplicationConnection, cmd: UserIpCommand
    ) -> Optional[Awaitable[None]]:
        """Called when we get a USER_IP command.

        Returns:
            An awaitable which handles the command if we are the master,
            otherwise None.
        """
        user_ip_cache_counter.inc()

        if self._is_master:
            return self._handle_user_ip(cmd)
        else:
            return None

    async def _handle_user_ip(self, cmd: UserIpCommand) -> None:
        """Record the client IP from a USER_IP command and poke the server
        notices sender for that user.
        """
        await self._store.insert_client_ip(
            cmd.user_id,
            cmd.access_token,
            cmd.ip,
            cmd.user_agent,
            cmd.device_id,
            cmd.last_seen,
        )

        # The server notices sender is always set when `_is_master` is true,
        # which is the only case in which on_USER_IP calls us.
        assert self._server_notices_sender is not None
        await self._server_notices_sender.on_user_ip(cmd.user_id)

    def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand) -> None:
        """Called when we get an RDATA command: queue it for processing."""
        if cmd.instance_name == self._instance_name:
            # Ignore RDATA that are just our own echoes
            return

        stream_name = cmd.stream_name
        inbound_rdata_count.labels(stream_name).inc()

        # We put the received command into a queue here for two reasons:
        #   1. so we don't try and concurrently handle multiple rows for the
        #      same stream, and
        #   2. so we don't race with getting a POSITION command and fetching
        #      missing RDATA.

        self._add_command_to_stream_queue(conn, cmd)

    async def _process_rdata(
        self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
    ) -> None:
        """Process an RDATA command

        Called after the command has been popped off the queue of inbound commands
        """
        try:
            row = STREAMS_MAP[stream_name].parse_row(cmd.row)
        except Exception as e:
            raise Exception(
                "Failed to parse RDATA: %r %r" % (stream_name, cmd.row)
            ) from e

        # make sure that we've processed a POSITION for this stream *on this
        # connection*. (A POSITION on another connection is no good, as there
        # is no guarantee that we have seen all the intermediate updates.)
        sbc = self._streams_by_connection.get(conn)
        if not sbc or stream_name not in sbc:
            # Let's drop the row for now, on the assumption we'll receive a
            # `POSITION` soon and we'll catch up correctly then.
            logger.debug(
                "Discarding RDATA for unconnected stream %s -> %s",
                stream_name,
                cmd.token,
            )
            return

        if cmd.token is None:
            # I.e. this is part of a batch of updates for this stream (in
            # which case batch until we get an update for the stream with a non
            # None token).
            self._pending_batches.setdefault(stream_name, []).append(row)
            return

        # Check if this is the last of a batch of updates
        rows = self._pending_batches.pop(stream_name, [])
        rows.append(row)

        stream = self._streams[stream_name]

        # Find where we previously streamed up to.
        current_token = stream.current_token(cmd.instance_name)

        # Discard this data if this token is earlier than the current
        # position. Note that streams can be reset (in which case you
        # expect an earlier token), but that must be preceded by a
        # POSITION command.
        if cmd.token <= current_token:
            logger.debug(
                "Discarding RDATA from stream %s at position %s before previous position %s",
                stream_name,
                cmd.token,
                current_token,
            )
        else:
            await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)

    async def on_rdata(
        self, stream_name: str, instance_name: str, token: int, rows: list
    ) -> None:
        """Called to handle a batch of replication data with a given stream token.

        Args:
            stream_name: name of the replication stream for this batch of rows
            instance_name: the instance that wrote the rows.
            token: stream token for this batch of rows
            rows: a list of Stream.ROW_TYPE objects as returned by
                Stream.parse_row.
        """
        logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
        await self._replication_data_handler.on_rdata(
            stream_name, instance_name, token, rows
        )

    def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand) -> None:
        """Called when we get a POSITION command: queue it for processing."""
        if cmd.instance_name == self._instance_name:
            # Ignore POSITION that are just our own echoes
            return

        logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())

        self._add_command_to_stream_queue(conn, cmd)

    async def _process_position(
        self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
    ) -> None:
        """Process a POSITION command

        Called after the command has been popped off the queue of inbound commands
        """
        stream = self._streams[stream_name]

        # We're about to go and catch up with the stream, so remove from set
        # of connected streams.
        for streams in self._streams_by_connection.values():
            streams.discard(stream_name)

        # We clear the pending batches for the stream as the fetching of the
        # missing updates below will fetch all rows in the batch.
        self._pending_batches.pop(stream_name, [])

        # Find where we previously streamed up to.
        current_token = stream.current_token(cmd.instance_name)

        # If the position token matches our current token then we're up to
        # date and there's nothing to do. Otherwise, fetch all updates
        # between then and now.
        missing_updates = cmd.prev_token != current_token
        while missing_updates:
            logger.info(
                "Fetching replication rows for '%s' between %i and %i",
                stream_name,
                current_token,
                cmd.new_token,
            )
            (updates, current_token, missing_updates) = await stream.get_updates_since(
                cmd.instance_name, current_token, cmd.new_token
            )

            # TODO: add some tests for this

            # Some streams return multiple rows with the same stream IDs,
            # which need to be processed in batches.

            for token, rows in _batch_updates(updates):
                await self.on_rdata(
                    stream_name,
                    cmd.instance_name,
                    token,
                    [stream.parse_row(row) for row in rows],
                )

        logger.info("Caught up with stream '%s' to %i", stream_name, cmd.new_token)

        # We've now caught up to position sent to us, notify handler.
        await self._replication_data_handler.on_position(
            cmd.stream_name, cmd.instance_name, cmd.new_token
        )

        self._streams_by_connection.setdefault(conn, set()).add(stream_name)

    def on_REMOTE_SERVER_UP(
        self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
    ) -> None:
        """Called when we get a new REMOTE_SERVER_UP command."""
        self._replication_data_handler.on_remote_server_up(cmd.data)

        self._notifier.notify_remote_server_up(cmd.data)

        # We relay to all other connections to ensure every instance gets the
        # notification.
        #
        # When configured to use redis we'll always only have one connection and
        # so this is a no-op (all instances will have already received the same
        # REMOTE_SERVER_UP command).
        #
        # For direct TCP connections this will relay to all other connections
        # connected to us. When on master this will correctly fan out to all
        # other direct TCP clients and on workers there'll only be the one
        # connection to master.
        #
        # (The logic here should also be sound if we have a mix of Redis and
        # direct TCP connections so long as there is only one traffic route
        # between two instances, but that is not currently supported).
        self.send_command(cmd, ignore_conn=conn)

    def new_connection(self, connection: IReplicationConnection) -> None:
        """Called when we have a new connection."""
        self._connections.append(connection)

        # If we are connected to replication as a client (rather than a server)
        # we need to reset the reconnection delay on the client factory (which
        # is used to do exponential back off when the connection drops).
        #
        # Ideally we would reset the delay when we've "fully established" the
        # connection (for some definition thereof) to stop us from tightlooping
        # on reconnection if something fails after this point and we drop the
        # connection. Unfortunately, we don't really have a better definition of
        # "fully established" than the connection being established.
        if self._factory:
            self._factory.resetDelay()

        # Tell the other end if we have any users currently syncing.
        currently_syncing = (
            self._presence_handler.get_currently_syncing_users_for_replication()
        )

        now = self._clock.time_msec()
        for user_id in currently_syncing:
            connection.send_command(
                UserSyncCommand(self._instance_id, user_id, True, now)
            )

    def lost_connection(self, connection: IReplicationConnection) -> None:
        """Called when a connection is closed/lost."""
        # we no longer need _streams_by_connection for this connection.
        streams = self._streams_by_connection.pop(connection, None)
        if streams:
            logger.info(
                "Lost replication connection; streams now disconnected: %s", streams
            )
        try:
            self._connections.remove(connection)
        except ValueError:
            # The connection was not in the list; nothing to remove.
            pass

    def connected(self) -> bool:
        """Do we have any replication connections open?

        Is used by e.g. `ReplicationStreamer` to no-op if nothing is connected.
        """
        return bool(self._connections)

    def send_command(
        self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
    ) -> None:
        """Send a command to all connected connections.

        If there are no connections at all the command is dropped and a warning
        is logged.

        Args:
            cmd: the command to send
            ignore_conn: If set don't send command to the given connection.
                Used when relaying commands from one connection to all others.
        """
        if self._connections:
            for connection in self._connections:
                if connection == ignore_conn:
                    continue

                try:
                    connection.send_command(cmd)
                except Exception:
                    # We probably want to catch some types of exceptions here
                    # and log them as warnings (e.g. connection gone), but I
                    # can't find what those exception types they would be.
                    logger.exception(
                        "Failed to write command %s to connection %s",
                        cmd.NAME,
                        connection,
                    )
        else:
            logger.warning("Dropping command as not connected: %r", cmd.NAME)

    def send_federation_ack(self, token: int) -> None:
        """Ack data for the federation stream. This allows the master to drop
        data stored purely in memory.
        """
        self.send_command(FederationAckCommand(self._instance_name, token))

    def send_user_sync(
        self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
    ) -> None:
        """Poke the master that a user has started/stopped syncing."""
        self.send_command(
            UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
        )

    def send_user_ip(
        self,
        user_id: str,
        access_token: str,
        ip: str,
        user_agent: str,
        device_id: str,
        last_seen: int,
    ) -> None:
        """Tell the master that the user made a request."""
        cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
        self.send_command(cmd)

    def send_remote_server_up(self, server: str) -> None:
        """Send a REMOTE_SERVER_UP command for the given server to all
        connections.
        """
        self.send_command(RemoteServerUpCommand(server))

    def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> None:
        """Called when a new update is available to stream to clients.

        We need to check if the client is interested in the stream or not

        Args:
            stream_name: name of the stream the update belongs to
            token: stream token for the update. The receiving side treats a
                None token as "part of a batch, more rows to follow".
            data: the row to send
        """
        self.send_command(RdataCommand(stream_name, self._instance_name, token, data))


UpdateToken = TypeVar("UpdateToken")
UpdateRow = TypeVar("UpdateRow")


def _batch_updates(
    updates: Iterable[Tuple[UpdateToken, UpdateRow]]
) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]:
    """Collect stream updates with the same token together

    Given a series of updates returned by Stream.get_updates_since(), collects
    the updates which share the same stream_id together.

    Only *consecutive* updates with the same token are merged.

    For example:

        [(1, a), (1, b), (2, c), (3, d), (3, e)]

    becomes:

        [
            (1, [a, b]),
            (2, [c]),
            (3, [d, e]),
        ]
    """

    update_iter = iter(updates)

    first_update = next(update_iter, None)
    if first_update is None:
        # empty input
        return

    current_batch_token = first_update[0]
    current_batch = [first_update[1]]

    for token, row in update_iter:
        if token != current_batch_token:
            # different token to the previous row: flush the previous
            # batch and start anew
            yield current_batch_token, current_batch
            current_batch_token = token
            current_batch = []

        current_batch.append(row)

    # flush the final batch
    yield current_batch_token, current_batch