import asyncio
import binascii
import functools
import ssl
from collections import OrderedDict
# The following typing names are only referenced from `# type:` comments
# (e.g. `# type: Awaitable[None]`, `# type: ClassVar[int]`), which linters do
# not count as usages - hence the `# noqa` markers.
from typing import Awaitable  # noqa
from typing import ClassVar  # noqa
from typing import Dict  # noqa
from typing import List  # noqa
from typing import Set  # noqa
from typing import (
    Any,
    Coroutine,
    Iterable,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

import websockets
from websockets.typing import Subprotocol

from . import util
from .common import (
    COOKIE_LENGTH,
    INITIATOR_ADDRESS,
    KEY_LENGTH,
    NONCE_LENGTH,
    RELAY_TIMEOUT,
    AddressType,
    ClientAddress,
    ClientState,
    CloseCode,
    ResponderAddress,
    SubProtocol,
)
from .events import (
    Event,
    EventRegistry,
)
from .exception import (
    Disconnected,
    DowngradeError,
    InternalError,
    MessageError,
    MessageFlowError,
    PathError,
    PingTimeoutError,
    ServerKeyError,
    SignalingError,
    SlotsFullError,
)
from .message import (
    ClientAuthMessage,
    ClientHelloMessage,
    DisconnectedMessage,
    DropResponderMessage,
    NewInitiatorMessage,
    NewResponderMessage,
    RelayMessage,
    SendErrorMessage,
    ServerAuthMessage,
    ServerHelloMessage,
)
from .protocol import (
    Path,
    PathClient,
)
from .typing import (
    ChosenSubProtocol,
    DisconnectedData,
    EventCallback,
    EventData,
    InitiatorPublicPermanentKey,
    ListOrTuple,
    MessageId,
    NoReturn,
    PathHex,
    ResponderPublicSessionKey,
    Result,
    ServerCookie,
    ServerPublicPermanentKey,
    ServerSecretPermanentKey,
)

# Public API of this module.
__all__ = (
    'serve',
    'ServerProtocol',
    'Paths',
    'Server',
)

# Constants
# Maximum number of seconds a disconnecting client's handler waits for the
# client's job queue to be joined before it carries on with the
# disconnect procedure (a timeout is only logged, not raised).
_JOB_QUEUE_JOIN_TIMEOUT = 10.0

# Do not export! (module-private type helpers follow)
102ST = TypeVar('ST', bound='Server') 103CloseFuture = Union['asyncio.Future[None]', Coroutine[Any, Any, None]] 104Keys = Mapping[ServerPublicPermanentKey, ServerSecretPermanentKey] 105 106 107async def serve( 108 ssl_context: Optional[ssl.SSLContext], 109 keys: Optional[Sequence[ServerSecretPermanentKey]], 110 paths: Optional['Paths'] = None, 111 host: Optional[str] = None, 112 port: int = 8765, 113 loop: Optional[asyncio.AbstractEventLoop] = None, 114 event_callbacks: Optional[Mapping[Event, Iterable[EventCallback]]] = None, 115 server_class: Optional[Type[ST]] = None, 116 ws_kwargs: Optional[Mapping[str, Any]] = None, 117) -> ST: 118 """ 119 Start serving SaltyRTC Signalling Clients. 120 121 Arguments: 122 - `ssl_context`: An `ssl.SSLContext` instance for WSS. 123 - `keys`: A sorted sequence of :class:`libnacl.public.SecretKey` 124 instances containing permanent private keys of the server. 125 The first key will be designated as the primary key. 126 - `paths`: A :class:`Paths` instance that maps path names to 127 :class:`Path` instances. Can be used to share paths on 128 multiple WebSockets. Defaults to an empty paths instance. 129 - `host`: The hostname or IP address the server will listen on. 130 Defaults to all interfaces. 131 - `port`: The port the client should connect to. Defaults to 132 `8765`. 133 - `loop`: A :class:`asyncio.BaseEventLoop` instance or `None` 134 if the default event loop should be used. 135 - `event_callbacks`: An optional dict with keys being an 136 :class:`Event` and the value being a list of callback 137 coroutines. The callback will be called every time the event 138 occurs. 139 - `server_class`: An optional :class:`Server` class to create 140 an instance from. 141 - `ws_kwargs`: Additional keyword arguments passed to 142 :func:`websockets.server.serve`. Note that the fields `ssl`, 143 `host`, `port`, `loop`, `subprotocols`, `ping_interval` and 144 `select_subprotocol` will be overridden. 
145 146 If the `compression` field is not explicitly set, 147 compression will be disabled (since the data to be compressed 148 is already encrypted, compression will have little to no 149 positive effect). 150 151 Raises :exc:`ServerKeyError` in case one or more keys have been repeated. 152 """ 153 if loop is None: 154 loop = asyncio.get_event_loop() 155 156 # Create paths if not given 157 if paths is None: 158 paths = Paths() 159 160 # Create server 161 if server_class is None: 162 server_class = cast('Type[ST]', Server) 163 server = server_class(keys, paths, loop=loop) 164 165 # Register event callbacks 166 if event_callbacks is not None: 167 for event, callbacks in event_callbacks.items(): 168 for callback in callbacks: 169 server.register_event_callback(event, callback) 170 171 # Prepare arguments for the WS server 172 if ws_kwargs is None: 173 ws_kwargs = {} 174 else: 175 ws_kwargs = dict(ws_kwargs) 176 ws_kwargs['ssl'] = ssl_context 177 ws_kwargs['host'] = host 178 ws_kwargs['port'] = port 179 ws_kwargs.setdefault('compression', None) 180 ws_kwargs['ping_interval'] = None # Disable the keep-alive of the transport library 181 ws_kwargs['subprotocols'] = server.subprotocols 182 ws_kwargs['select_subprotocol'] = server.protocol_class.select_subprotocol 183 184 # Start WS server 185 ws_server = await websockets.serve(server.handler, **ws_kwargs) 186 187 # Set WS server instance 188 server.server = ws_server 189 190 # Return server 191 return server 192 193 194class ServerProtocol: 195 PATH_LENGTH = KEY_LENGTH * 2 # type: ClassVar[int] 196 197 __slots__ = ( 198 '_log', 199 '_loop', 200 '_server', 201 'subprotocol', 202 'path', 203 'client', 204 'handler_task' 205 ) 206 207 @classmethod 208 def select_subprotocol( 209 cls, 210 client_subprotocols: Sequence[Subprotocol], 211 server_subprotocols: Sequence[Subprotocol], 212 ) -> Optional[Subprotocol]: 213 # Determine common subprotocols 214 subprotocols = set(client_subprotocols) & set(server_subprotocols) 215 if 
len(subprotocols) == 0: 216 return None 217 218 # Sort by combined index 219 def _priority(subprotocol: Subprotocol) -> int: 220 return (client_subprotocols.index(subprotocol) + 221 server_subprotocols.index(subprotocol)) 222 return sorted(subprotocols, key=_priority)[0] 223 224 def __init__( 225 self, 226 server: 'Server', 227 subprotocol: SubProtocol, 228 connection: websockets.WebSocketServerProtocol, 229 ws_path: str, 230 loop: Optional[asyncio.AbstractEventLoop] = None, 231 ) -> None: 232 self._log = util.get_logger('server.protocol') 233 self._loop = asyncio.get_event_loop() if loop is None else loop 234 235 # Server instance and subprotocol 236 self._server = server 237 self.subprotocol = subprotocol 238 239 # Path and client instance 240 self.path = None # type: Optional[Path] 241 self.client = None # type: Optional[PathClient] 242 self._log.debug('New connection on WS path {}', ws_path) 243 244 # Get path and client instance as early as possible 245 try: 246 path, client = self.get_path_client(connection, ws_path) 247 except PathError as exc: 248 self._log.notice('Closing due to path error: {}', exc) 249 250 async def close_with_protocol_error() -> None: 251 await connection.close(code=CloseCode.protocol_error.value) 252 self._server.notify_disconnected( 253 None, DisconnectedData(CloseCode.protocol_error.value)) 254 handler_coroutine = close_with_protocol_error() 255 else: 256 handler_coroutine = self.handler() 257 client.log.info('Connection established') 258 client.log.debug('Worker started') 259 260 # Store path and client 261 self.path = path 262 self.client = client 263 self._server.register(self) 264 265 # Start handler task 266 log_handler = functools.partial( 267 self._log.exception, 'Unhandled exception in protocol handler:') 268 # noinspection PyTypeChecker 269 self.handler_task = self._loop.create_task( 270 util.log_exception(handler_coroutine, log_handler)) 271 272 async def handler(self) -> None: 273 client, path = self.client, self.path 274 
assert client is not None 275 assert path is not None 276 277 # Handle client until disconnected or an exception occurred 278 hex_path = PathHex(binascii.hexlify(path.initiator_key).decode('ascii')) 279 close_future = asyncio.Future(loop=self._loop) # type: asyncio.Future[None] 280 try: 281 await self.handle_client() 282 except Disconnected as exc: 283 client.log.info('Connection closed (code: {})', exc.reason) 284 close_future.set_result(None) 285 close_awaitable = close_future # type: Awaitable[None] 286 self._server.notify_disconnected(hex_path, DisconnectedData(exc.reason)) 287 except PingTimeoutError: 288 client.log.info('Closing because of a ping timeout') 289 close_awaitable = client.close(CloseCode.timeout.value) 290 self._server.notify_disconnected( 291 hex_path, DisconnectedData(CloseCode.timeout.value)) 292 except SlotsFullError as exc: 293 client.log.notice('Closing because all path slots are full: {}', exc) 294 close_awaitable = client.close(code=CloseCode.path_full_error.value) 295 self._server.notify_disconnected( 296 hex_path, DisconnectedData(CloseCode.path_full_error.value)) 297 except ServerKeyError as exc: 298 client.log.notice('Closing due to server key error: {}', exc) 299 close_awaitable = client.close(code=CloseCode.invalid_key.value) 300 self._server.notify_disconnected( 301 hex_path, DisconnectedData(CloseCode.invalid_key.value)) 302 except InternalError as exc: 303 client.log.exception('Closing due to an internal error:', exc) 304 close_awaitable = client.close(code=CloseCode.internal_error.value) 305 self._server.notify_disconnected( 306 hex_path, DisconnectedData(CloseCode.internal_error.value)) 307 except SignalingError as exc: 308 client.log.notice('Closing due to protocol error: {}', exc) 309 close_awaitable = client.close(code=CloseCode.protocol_error.value) 310 self._server.notify_disconnected( 311 hex_path, DisconnectedData(CloseCode.protocol_error.value)) 312 except Exception as exc: 313 client.log.exception('Closing due to 
exception:', exc) 314 close_awaitable = client.close(code=CloseCode.internal_error.value) 315 self._server.notify_disconnected( 316 hex_path, DisconnectedData(CloseCode.internal_error.value)) 317 else: 318 # Note: This should not ever happen since 'handle_client' 319 # contains an infinite loop that only stops due to an exception. 320 client.log.error('Client closed without exception') 321 close_future.set_result(None) 322 close_awaitable = close_future 323 324 # Schedule closing of the client 325 # Note: This ensures the client is closed soon even if the job queue is holding 326 # us up. 327 if not isinstance(close_awaitable, asyncio.Future): 328 log_handler = functools.partial( 329 self._log.exception, 'Unhandled exception in closing procedure:') 330 # noinspection PyTypeChecker 331 close_awaitable = self._loop.create_task( 332 util.log_exception(close_awaitable, log_handler)) 333 334 # Wait until all queued jobs have been processed and the job queue runner 335 # returned. 336 # 337 # Note: This ensure that a send-error message (and potentially other messages) 338 # are enqueued towards other clients before the disconnect message. 
339 try: 340 await asyncio.wait_for( 341 client.jobs.join(), _JOB_QUEUE_JOIN_TIMEOUT, loop=self._loop) 342 except asyncio.TimeoutError: 343 client.log.error( 344 'Job queue did not complete within {} seconds', _JOB_QUEUE_JOIN_TIMEOUT) 345 else: 346 client.log.debug('Job queue completed') 347 348 # Send disconnected message if the path is not empty and client was authenticated 349 if path.empty: 350 description = 'Skipping potential disconnected message as the path has ' \ 351 'already been detached' 352 client.log.debug(description) 353 elif client.state != ClientState.authenticated: 354 client.log.debug( 355 'Skipping potential disconnected message due to {} state', 356 client.state.name) 357 else: 358 # Initiator: Send to all responders 359 if client.type == AddressType.initiator: 360 responder_ids = path.get_responder_ids() 361 coroutines = [] # type: List[Coroutine[Any, Any, None]] 362 for responder_id in responder_ids: 363 responder = path.get_responder(responder_id) 364 365 # Create message and add send coroutine to job queue of the responder 366 message = DisconnectedMessage.create( 367 responder_id, INITIATOR_ADDRESS) 368 responder.log.debug('Enqueueing disconnected message') 369 coroutines.append(responder.jobs.enqueue(responder.send(message))) 370 try: 371 await asyncio.gather(*coroutines, loop=self._loop) 372 except Exception as exc: 373 description = 'Error while dispatching disconnected messages to ' \ 374 'responders:' 375 client.log.exception(description, exc) 376 # Responder: Send to initiator (if present) 377 elif client.type == AddressType.responder: 378 try: 379 initiator = path.get_initiator() 380 except KeyError: 381 pass # No initiator present 382 else: 383 # Create message and add send coroutine to job queue of the 384 # initiator 385 message = DisconnectedMessage.create( 386 INITIATOR_ADDRESS, ResponderAddress(client.id)) 387 initiator.log.debug('Enqueueing disconnected message') 388 try: 389 await 
initiator.jobs.enqueue(initiator.send(message)) 390 except Exception as exc: 391 description = 'Error while dispatching disconnected message' \ 392 'to initiator:' 393 client.log.exception(description, exc) 394 else: 395 client.log.error('Invalid address type: {}', client.type) 396 397 # Wait for the connection to be closed 398 await close_awaitable 399 client.log.debug('WS connection closed') 400 401 # Remove protocol from server and stop 402 self._server.unregister(self) 403 client.log.debug('Worker stopped') 404 405 def close(self, code: CloseCode) -> None: 406 """ 407 Close the underlying connection and stop the protocol. 408 409 Arguments: 410 - `code`: The close code. 411 """ 412 # Note: The client will be set as early as possible without any yielding. 413 # Thus, self.client is either set and can be closed or the connection 414 # is already closing (see the constructor and 'get_path_client') 415 if self.client is not None: 416 # We need to use 'drop' in order to prevent the server from sending a 417 # 'disconnect' message for each client. 418 try: 419 self._drop_client(self.client, code) 420 except KeyError: 421 # We can safely ignore this since clients will be removed immediately 422 # from the path in case they are being dropped by another client. 423 pass 424 except ValueError: 425 # We can also safely ignore this since a client may have already removed 426 # itself from the path. 
427 pass 428 429 def get_path_client( 430 self, 431 connection: websockets.WebSocketServerProtocol, 432 ws_path: str, 433 ) -> Tuple[Path, PathClient]: 434 # Extract public key from path 435 initiator_key_hex = ws_path[1:] 436 437 # Validate key 438 if len(initiator_key_hex) != self.PATH_LENGTH: 439 raise PathError('Invalid path length: {}'.format(len(initiator_key_hex))) 440 try: 441 initiator_key = InitiatorPublicPermanentKey( 442 binascii.unhexlify(initiator_key_hex)) 443 except (binascii.Error, ValueError) as exc: 444 raise PathError('Could not unhexlify path') from exc 445 446 # Get path instance 447 path = self._server.paths.get(initiator_key) 448 449 # Create client instance 450 client = PathClient(connection, path.number, initiator_key, loop=self._loop) 451 452 # Attach client to path as 'pending' 453 path.add_pending(client) 454 455 # Return path and client 456 return path, client 457 458 async def handle_client(self) -> None: 459 """ 460 SignalingError 461 PathError 462 Disconnected 463 MessageError 464 MessageFlowError 465 SlotsFullError 466 DowngradeError 467 ServerKeyError 468 InternalError 469 """ 470 path, client = self.path, self.client 471 assert path is not None 472 assert client is not None 473 tasks = set() # type: Set[Coroutine[Any, Any, None]] 474 475 # Do handshake 476 client.log.debug('Starting handshake') 477 try: 478 await self.handshake() 479 except Exception as exc: 480 client.log.info('Handshake aborted') 481 482 # Encountered an exception during the handshake. 483 # Note: We already know the result (the exception), so we can cancel both 484 # job queue and tasks. 485 result = Result(exc) 486 client.jobs.cancel(result) 487 client.tasks.cancel(result) 488 else: 489 # Check if the client is still connected to the path or has already been 490 # dropped. 491 # 492 # Note: This can happen when the client is being picked up and dropped by 493 # another client while running the handshake. 
To prevent other race 494 # conditions, we have to add the client instance to the path early 495 # during the handshake. 496 is_connected = path.has_client(client) 497 if is_connected: 498 client.log.info('Handshake completed') 499 else: 500 client.log.info('Handshake completed but client already dropped') 501 502 # Task: Poll for messages 503 hex_path = PathHex(binascii.hexlify(path.initiator_key).decode('ascii')) 504 if client.type == AddressType.initiator: 505 self._server.notify_initiator_connected(hex_path) 506 if is_connected: 507 client.log.debug('Starting runner for initiator') 508 tasks.add(self.initiator_receive_loop()) 509 elif client.type == AddressType.responder: 510 self._server.notify_responder_connected(hex_path) 511 if is_connected: 512 client.log.debug('Starting runner for responder') 513 tasks.add(self.responder_receive_loop()) 514 else: 515 raise ValueError('Invalid address type: {}'.format(client.type)) 516 517 # Task: Keep alive 518 if is_connected: 519 client.log.debug('Starting keep-alive task') 520 tasks.add(self.keep_alive_loop()) 521 522 # Start the tasks and the job queue runner 523 client.jobs.start(client.tasks.cancel) 524 client.tasks.start(tasks) 525 526 # Wait until complete 527 # Note: This method ensures us that all tasks have been cancelled 528 # when it returns. 529 result = await client.tasks.await_result() 530 531 # Cancel pending jobs and remove client from path 532 # Note: Removing the client needs to be done here since the re-raise hands 533 # the task back into the event loop allowing other tasks to get the 534 # client's path instance from the path while it is already effectively 535 # disconnected. 536 client.jobs.cancel(result) 537 try: 538 path.remove_client(client) 539 except KeyError: 540 # We can safely ignore this since clients will be removed immediately 541 # from the path in case they are being dropped by another client. 
542 pass 543 except ValueError: 544 # We can also safely ignore this since a client may have already removed 545 # itself from the path. 546 pass 547 548 # Clean the path (if still attached) 549 if path.attached: 550 self._server.paths.clean(path) 551 552 # Done! Raise the result 553 raise result 554 555 async def handshake(self) -> None: 556 """ 557 Disconnected 558 MessageError 559 MessageFlowError 560 SlotsFullError 561 DowngradeError 562 ServerKeyError 563 """ 564 client = self.client 565 assert client is not None 566 567 # Send server-hello 568 server_hello = ServerHelloMessage.create( 569 ServerPublicPermanentKey(client.server_key.pk)) 570 client.log.debug('Sending server-hello') 571 await client.send(server_hello) 572 573 # Receive client-hello or client-auth 574 client.log.debug('Waiting for client-hello or client-auth') 575 client_auth = await client.receive() 576 if isinstance(client_auth, ClientAuthMessage): 577 client.log.debug('Received client-auth') 578 # Client is the initiator 579 client.type = AddressType.initiator 580 await self.handshake_initiator(client_auth) 581 elif isinstance(client_auth, ClientHelloMessage): 582 client.log.debug('Received client-hello') 583 # Client is a responder 584 client.type = AddressType.responder 585 await self.handshake_responder(client_auth) 586 else: 587 error = "Expected 'client-hello' or 'client-auth', got '{}'" 588 raise MessageFlowError(error.format(client_auth.type)) 589 590 async def handshake_initiator(self, client_auth: ClientAuthMessage) -> None: 591 """ 592 Disconnected 593 MessageError 594 MessageFlowError 595 DowngradeError 596 ServerKeyError 597 """ 598 path, initiator = self.path, self.client 599 assert path is not None 600 assert initiator is not None 601 602 # Handle client-auth 603 self._handle_client_auth(client_auth) 604 605 # Authenticated 606 previous_initiator = path.set_initiator(initiator) 607 if previous_initiator is not None: 608 # Drop previous initiator using its job queue 609 
path.log.debug('Dropping previous initiator {}', previous_initiator) 610 previous_initiator.log.debug('Dropping (another initiator connected)') 611 self._drop_client(previous_initiator, CloseCode.drop_by_initiator) 612 613 # Send new-initiator message if any responder is present 614 responder_ids = path.get_responder_ids() 615 coroutines = [] # type: List[Coroutine[Any, Any, None]] 616 for responder_id in responder_ids: 617 responder = path.get_responder(responder_id) 618 619 # Create message and add send coroutine to job queue of the responder 620 new_initiator = NewInitiatorMessage.create(responder_id) 621 responder.log.debug('Enqueueing new-initiator message') 622 coroutines.append(responder.jobs.enqueue(responder.send(new_initiator))) 623 await asyncio.gather(*coroutines, loop=self._loop) 624 625 # Send server-auth 626 responder_ids = list(path.get_responder_ids()) 627 server_auth = ServerAuthMessage.create( 628 INITIATOR_ADDRESS, initiator.cookie_in, 629 sign_keys=len(self._server.keys) > 0, responder_ids=responder_ids) 630 initiator.log.debug('Sending server-auth including responder ids') 631 await initiator.send(server_auth) 632 633 async def handshake_responder(self, client_hello: ClientHelloMessage) -> None: 634 """ 635 Disconnected 636 MessageError 637 MessageFlowError 638 SlotsFullError 639 DowngradeError 640 ServerKeyError 641 """ 642 path, responder = self.path, self.client 643 assert path is not None 644 assert responder is not None 645 646 # Set key on client 647 responder.set_client_key( 648 ResponderPublicSessionKey(client_hello.client_public_key)) 649 650 # Receive client-auth 651 client_auth = await responder.receive() 652 if not isinstance(client_auth, ClientAuthMessage): 653 error = "Expected 'client-auth', got '{}'" 654 raise MessageFlowError(error.format(client_auth.type)) 655 656 # Handle client-auth 657 self._handle_client_auth(client_auth) 658 659 # Authenticated 660 id_ = path.add_responder(responder) 661 662 # Send new-responder message 
if initiator is present 663 initiator = None # type: Optional[PathClient] 664 try: 665 initiator = path.get_initiator() 666 except KeyError: 667 pass 668 else: 669 # Create message and add send coroutine to job queue of the initiator 670 new_responder = NewResponderMessage.create(id_) 671 initiator.log.debug('Enqueueing new-responder message') 672 await initiator.jobs.enqueue(initiator.send(new_responder)) 673 674 # Send server-auth 675 server_auth = ServerAuthMessage.create( 676 ResponderAddress(responder.id), responder.cookie_in, 677 sign_keys=len(self._server.keys) > 0, 678 initiator_connected=initiator is not None) 679 responder.log.debug('Sending server-auth without responder ids') 680 await responder.send(server_auth) 681 682 async def initiator_receive_loop(self) -> NoReturn: 683 path, initiator = self.path, self.client 684 assert path is not None 685 assert initiator is not None 686 while True: 687 # Receive relay message or drop-responder 688 message = await initiator.receive() 689 690 # Relay 691 if isinstance(message, RelayMessage): 692 # Lookup responder 693 responder = None # type: Optional[PathClient] 694 try: 695 responder_id = ResponderAddress(message.destination) 696 responder = path.get_responder(responder_id) 697 except KeyError: 698 pass 699 # Send to responder 700 await self.relay_message( 701 responder, ClientAddress(message.destination), message) 702 # Drop-responder 703 elif isinstance(message, DropResponderMessage): 704 # Lookup responder 705 try: 706 responder = path.get_responder(message.responder_id) 707 except KeyError: 708 log_message = 'Responder {} already dropped, nothing to do' 709 path.log.debug(log_message, message.responder_id) 710 else: 711 # Drop responder using its job queue 712 path.log.debug( 713 'Dropping responder {}, reason: {}', responder, message.reason) 714 responder.log.debug( 715 'Dropping (requested by initiator), reason: {}', message.reason) 716 self._drop_client(responder, CloseCode(message.reason)) 717 else: 718 
error = "Expected relay message or 'drop-responder', got '{}'" 719 raise MessageFlowError(error.format(message.type)) 720 721 async def responder_receive_loop(self) -> NoReturn: 722 path, responder = self.path, self.client 723 assert path is not None 724 assert responder is not None 725 while True: 726 # Receive relay message 727 message = await responder.receive() 728 729 # Relay 730 if isinstance(message, RelayMessage): 731 # Lookup initiator 732 initiator = None # type: Optional[PathClient] 733 try: 734 initiator = path.get_initiator() 735 except KeyError: 736 pass 737 # Send to initiator 738 await self.relay_message(initiator, INITIATOR_ADDRESS, message) 739 else: 740 error = "Expected relay message, got '{}'" 741 raise MessageFlowError(error.format(message.type)) 742 743 async def relay_message( 744 self, 745 destination: Optional[PathClient], 746 destination_id: ClientAddress, 747 message: RelayMessage, 748 ) -> None: 749 source = self.client 750 assert source is not None 751 752 # Prepare message 753 source.log.debug('Packing relay message') 754 message_id = MessageId(message.pack(source)[COOKIE_LENGTH:NONCE_LENGTH]) 755 756 async def send_error_message() -> None: 757 assert source is not None 758 # Create message and add send coroutine to job queue of the source 759 error = SendErrorMessage.create(ClientAddress(source.id), message_id) 760 source.log.info('Relaying failed, enqueuing send-error') 761 await source.jobs.enqueue(source.send(error)) 762 763 # Destination not connected? 
Send 'send-error' to source 764 if destination is None: 765 error_message = ('Cannot relay message, no connection for ' 766 'destination id 0x{:02x}') 767 source.log.info(error_message, destination_id) 768 await send_error_message() 769 return 770 771 # Add send task to job queue of the destination 772 task = self._loop.create_task(destination.send(message)) 773 destination.log.debug('Enqueueing relayed message from 0x{:02x}', source.id) 774 await destination.jobs.enqueue(task) 775 776 # noinspection PyBroadException 777 try: 778 # Wait for send task to complete 779 await asyncio.wait_for(task, RELAY_TIMEOUT, loop=self._loop) 780 except asyncio.TimeoutError: 781 # Timed out, send 'send-error' to source 782 log_message = 'Sending relayed message to 0x{:02x} timed out' 783 source.log.info(log_message, destination_id) 784 await send_error_message() 785 except Exception as exc: 786 # Handle cancellation of the client 787 if isinstance(exc, asyncio.CancelledError) and source.tasks.have_result: 788 raise 789 790 # An exception has been triggered while sending the message. 791 # Note: We don't care about the actual exception as the job 792 # queue runner will also trigger that exception on the 793 # destination client's handler who will log what happened. 
794 log_message = 'Sending relayed message failed, receiver 0x{:02x} is gone' 795 source.log.info(log_message, destination_id) 796 await send_error_message() 797 else: 798 source.log.debug('Sending relayed message to 0x{:02x} successful', 799 destination.id) 800 801 async def keep_alive_loop(self) -> NoReturn: 802 """ 803 Disconnected 804 PingTimeoutError 805 """ 806 client = self.client 807 assert client is not None 808 while True: 809 # Wait 810 # noinspection PyTypeChecker 811 await asyncio.sleep(client.keep_alive_interval, loop=self._loop) 812 813 # Send ping and wait for pong 814 client.log.debug('Ping') 815 pong_future = await client.ping() 816 try: 817 await asyncio.wait_for( 818 client.wait_pong(pong_future), client.keep_alive_timeout, 819 loop=self._loop) 820 except asyncio.TimeoutError: 821 client.log.debug('Ping timed out') 822 raise PingTimeoutError(str(client)) 823 else: 824 client.log.debug('Pong') 825 client.keep_alive_pings += 1 826 827 def _handle_client_auth(self, client_auth: ClientAuthMessage) -> None: 828 """ 829 MessageError 830 DowngradeError 831 ServerKeyError 832 """ 833 client = self.client 834 assert client is not None 835 836 # Validate cookie and ensure no sub-protocol downgrade took place 837 self._validate_cookie(client_auth.server_cookie, client.cookie_out) 838 self._validate_subprotocol(client_auth.subprotocols) 839 840 # Set the keep alive interval (if any) 841 if client_auth.ping_interval is not None: 842 client.log.debug( 843 'Setting keep-alive interval to {}', client_auth.ping_interval) 844 client.keep_alive_interval = client_auth.ping_interval 845 846 # Set the public permanent key the client wants to use (or fallback to primary) 847 server_keys_count = len(self._server.keys) 848 if client_auth.server_key is not None: 849 # No permanent key pair? 
850 if server_keys_count == 0: 851 raise ServerKeyError('Server does not have a permanent public key') 852 853 # Find the key instance 854 server_key = self._server.keys.get(client_auth.server_key) 855 if server_key is None: 856 raise ServerKeyError( 857 'Server does not have the requested permanent public key') 858 859 # Set the key instance on the client 860 client.server_permanent_key = server_key 861 elif server_keys_count > 0: 862 # Use primary permanent key 863 client.server_permanent_key = next(iter(self._server.keys.values())) 864 865 def _validate_cookie( 866 self, 867 expected_cookie: ServerCookie, 868 actual_cookie: ServerCookie, 869 ) -> None: 870 """ 871 MessageError 872 """ 873 client = self.client 874 assert client is not None 875 client.log.debug('Validating cookie') 876 if not util.consteq(expected_cookie, actual_cookie): 877 raise MessageError('Cookies do not match') 878 879 def _validate_subprotocol( 880 self, 881 client_subprotocols: ListOrTuple[ChosenSubProtocol], 882 ) -> None: 883 """ 884 MessageError 885 DowngradeError 886 """ 887 client = self.client 888 assert client is not None 889 client.log.debug( 890 'Checking for subprotocol downgrade, client: {}, server: {}', 891 client_subprotocols, self._server.subprotocols) 892 chosen = self.select_subprotocol( 893 cast(Sequence[Subprotocol], client_subprotocols), 894 cast(Sequence[Subprotocol], self._server.subprotocols)) 895 if chosen != self.subprotocol.value: 896 raise DowngradeError('Subprotocol downgrade detected') 897 898 def _drop_client(self, client: PathClient, code: CloseCode) -> None: 899 """ 900 Mark the client as closed, schedule the closing procedure on 901 the client's job queue and remove it from the path. 902 903 .. important:: This should only be called by clients dropping 904 another client or when the server is closing. 905 906 Arguments: 907 - `client`: The client to be dropped. 908 - `close`: The close code. 
909 910 Raises: 911 - :exc:`KeyError` in case the client is not attached to the 912 path. 913 - :exc:`ValueError` in case the path has been detached. 914 """ 915 # Drop the client 916 client.drop(code) 917 918 # Remove the client from the path 919 path = self.path 920 assert path is not None 921 path.remove_client(client) 922 923 924class Paths: 925 __slots__ = ('_log', 'number', 'paths') 926 927 def __init__(self) -> None: 928 self._log = util.get_logger('paths') 929 self.number = 0 930 self.paths = {} # type: Dict[InitiatorPublicPermanentKey, Path] 931 932 def get(self, initiator_key: InitiatorPublicPermanentKey) -> Path: 933 if self.paths.get(initiator_key) is None: 934 self.number += 1 935 self.paths[initiator_key] = Path(initiator_key, self.number, attached=True) 936 self._log.debug('Created new path: {}', self.number) 937 return self.paths[initiator_key] 938 939 def clean(self, path: Path) -> None: 940 if path.empty: 941 path.attached = False 942 try: 943 del self.paths[path.initiator_key] 944 except KeyError: 945 self._log.error('Path {} has already been removed', path.number) 946 else: 947 self._log.debug('Removed empty path: {}', path.number) 948 path.clear() 949 950 951class Server: 952 subprotocols = [ 953 SubProtocol.saltyrtc_v1.value 954 ] # type: ClassVar[Sequence[SubProtocol]] 955 956 def __init__( 957 self, 958 keys: Optional[Sequence[ServerSecretPermanentKey]], 959 paths: Paths, 960 loop: Optional[asyncio.AbstractEventLoop] = None, 961 ) -> None: 962 self._log = util.get_logger('server') 963 self._loop = asyncio.get_event_loop() if loop is None else loop 964 965 # Protocol class 966 self.protocol_class = ServerProtocol # type: Type[ServerProtocol] 967 968 # WebSocket server instance 969 self._server = None # type: Optional[websockets.server.WebSocketServer] 970 971 # Validate & store keys 972 if keys is None: 973 keys = [] 974 if len(keys) != len({key.pk for key in keys}): 975 raise ServerKeyError('Repeated permanent keys') 976 self.keys = 
OrderedDict( 977 ((ServerPublicPermanentKey(key.pk), key) for key in keys)) # type: Keys 978 979 # Store paths 980 self.paths = paths 981 982 # Store server protocols and closing task 983 self.protocols = set() # type: Set[ServerProtocol] 984 self._close_task = None # type: Optional[asyncio.Task[None]] 985 986 # Event Registry 987 self._events = EventRegistry() 988 989 @property 990 def server(self) -> websockets.server.WebSocketServer: 991 assert self._server is not None 992 return self._server 993 994 @server.setter 995 def server(self, server: websockets.server.WebSocketServer) -> None: 996 self._server = server 997 self._log.debug('Server instance: {}', server) 998 999 async def handler( 1000 self, 1001 connection: websockets.WebSocketServerProtocol, 1002 ws_path: str, 1003 ) -> None: 1004 # Closing? Drop immediately 1005 if self._close_task is not None: 1006 await connection.close(CloseCode.going_away.value) 1007 return 1008 1009 # Convert sub-protocol 1010 subprotocol = None # type: Optional[SubProtocol] 1011 try: 1012 subprotocol = SubProtocol(connection.subprotocol) 1013 except ValueError: 1014 pass 1015 1016 # Determine ServerProtocol instance by selected sub-protocol 1017 if subprotocol != SubProtocol.saltyrtc_v1: 1018 self._log.notice('Could not negotiate a sub-protocol, dropping client') 1019 # We need to close the connection manually as the client may choose 1020 # to ignore 1021 await connection.close(code=CloseCode.subprotocol_error.value) 1022 self.notify_disconnected( 1023 None, DisconnectedData(CloseCode.subprotocol_error.value)) 1024 else: 1025 assert subprotocol is not None 1026 protocol = self.protocol_class( 1027 self, subprotocol, connection, ws_path, loop=self._loop) 1028 await protocol.handler_task 1029 1030 def register(self, protocol: ServerProtocol) -> None: 1031 self.protocols.add(protocol) 1032 self._log.debug('Protocol registered: {}', protocol) 1033 1034 def unregister(self, protocol: ServerProtocol) -> None: 1035 
self.protocols.remove(protocol) 1036 self._log.debug('Protocol unregistered: {}', protocol) 1037 1038 def register_event_callback(self, event: Event, callback: EventCallback) -> None: 1039 """ 1040 Register a new event callback. 1041 """ 1042 self._events.register(event, callback) 1043 1044 def notify_initiator_connected(self, path: PathHex) -> None: 1045 self._raise_event(Event.initiator_connected, path, None) 1046 1047 def notify_responder_connected(self, path: PathHex) -> None: 1048 self._raise_event(Event.responder_connected, path, None) 1049 1050 def notify_disconnected( 1051 self, 1052 path: Optional[PathHex], 1053 data: DisconnectedData, 1054 ) -> None: 1055 self._raise_event(Event.disconnected, path, data) 1056 1057 def _raise_event( 1058 self, 1059 event: Event, 1060 path: Optional[PathHex], 1061 data: EventData, 1062 ) -> None: 1063 """ 1064 Raise an event and invoke all registered event callbacks. 1065 1066 Arguments: 1067 - `event`: Event to be raised. 1068 - `path`: Associated path in hexadecimal representation or 1069 `None` if not available. 1070 - `data`: Additional data for the event as explained for 1071 :class:`EventRegistry`. 1072 """ 1073 for callback in self._events.get_callbacks(event): 1074 coroutine = callback(event, path, data) 1075 log_handler = functools.partial( 1076 self._log.exception, 'Unhandled exception in event handler:') 1077 # noinspection PyTypeChecker 1078 self._loop.create_task(util.log_exception(coroutine, log_handler)) 1079 1080 def close(self) -> None: 1081 """ 1082 Close open connections and the server. 
1083 """ 1084 if self._close_task is None: 1085 log_handler = functools.partial( 1086 self._log.exception, 'Exception while closing:') 1087 # noinspection PyTypeChecker 1088 self._close_task = self._loop.create_task( 1089 util.log_exception(self._close_after_all_protocols_closed(), log_handler)) 1090 1091 async def wait_closed(self) -> None: 1092 """ 1093 Wait until all connections and the server itself has been 1094 closed. 1095 """ 1096 await self.server.wait_closed() 1097 1098 async def _close_after_all_protocols_closed( 1099 self, 1100 timeout: Optional[float] = None, 1101 ) -> None: 1102 # Schedule closing all protocols 1103 self._log.info('Closing protocols') 1104 if len(self.protocols) > 0: 1105 async def _close_and_wait() -> None: 1106 # Wait until all connections have been scheduled to be closed 1107 for protocol in self.protocols: 1108 protocol.close(CloseCode.going_away) 1109 1110 # Wait until all protocols have returned 1111 handler_tasks = [protocol.handler_task for protocol in self.protocols] 1112 await asyncio.gather(*handler_tasks, loop=self._loop) 1113 1114 await asyncio.wait_for(_close_and_wait(), timeout, loop=self._loop) 1115 1116 # Now we can close the server 1117 self._log.info('Closing server') 1118 self.server.close() 1119