1import asyncio
2import binascii
3import functools
4import ssl
5from collections import OrderedDict
6from typing import Awaitable  # noqa
7from typing import ClassVar  # noqa
8from typing import Dict  # noqa
9from typing import List  # noqa
10from typing import Set  # noqa
11from typing import (
12    Any,
13    Coroutine,
14    Iterable,
15    Mapping,
16    Optional,
17    Sequence,
18    Tuple,
19    Type,
20    TypeVar,
21    Union,
22    cast,
23)
24
25import websockets
26from websockets.typing import Subprotocol
27
28from . import util
29from .common import (
30    COOKIE_LENGTH,
31    INITIATOR_ADDRESS,
32    KEY_LENGTH,
33    NONCE_LENGTH,
34    RELAY_TIMEOUT,
35    AddressType,
36    ClientAddress,
37    ClientState,
38    CloseCode,
39    ResponderAddress,
40    SubProtocol,
41)
42from .events import (
43    Event,
44    EventRegistry,
45)
46from .exception import (
47    Disconnected,
48    DowngradeError,
49    InternalError,
50    MessageError,
51    MessageFlowError,
52    PathError,
53    PingTimeoutError,
54    ServerKeyError,
55    SignalingError,
56    SlotsFullError,
57)
58from .message import (
59    ClientAuthMessage,
60    ClientHelloMessage,
61    DisconnectedMessage,
62    DropResponderMessage,
63    NewInitiatorMessage,
64    NewResponderMessage,
65    RelayMessage,
66    SendErrorMessage,
67    ServerAuthMessage,
68    ServerHelloMessage,
69)
70from .protocol import (
71    Path,
72    PathClient,
73)
74from .typing import (
75    ChosenSubProtocol,
76    DisconnectedData,
77    EventCallback,
78    EventData,
79    InitiatorPublicPermanentKey,
80    ListOrTuple,
81    MessageId,
82    NoReturn,
83    PathHex,
84    ResponderPublicSessionKey,
85    Result,
86    ServerCookie,
87    ServerPublicPermanentKey,
88    ServerSecretPermanentKey,
89)
90
# Public API of this module
__all__ = (
    'serve',
    'ServerProtocol',
    'Paths',
    'Server',
)

# Constants
# Number of seconds the protocol handler waits for a client's job queue to be
# joined during the disconnect procedure. On timeout an error is logged and the
# procedure continues anyway.
_JOB_QUEUE_JOIN_TIMEOUT = 10.0

# Do not export!
# Type variable for the `server_class` argument and the return type of `serve`.
ST = TypeVar('ST', bound='Server')
# Alias for either a future or a coroutine resolving to `None` (presumably
# describing a pending close operation, per its name).
CloseFuture = Union['asyncio.Future[None]', Coroutine[Any, Any, None]]
# Alias for a mapping from server public permanent keys to the matching
# server secret permanent keys.
Keys = Mapping[ServerPublicPermanentKey, ServerSecretPermanentKey]
105
106
async def serve(
        ssl_context: Optional[ssl.SSLContext],
        keys: Optional[Sequence[ServerSecretPermanentKey]],
        paths: Optional['Paths'] = None,
        host: Optional[str] = None,
        port: int = 8765,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        event_callbacks: Optional[Mapping[Event, Iterable[EventCallback]]] = None,
        server_class: Optional[Type[ST]] = None,
        ws_kwargs: Optional[Mapping[str, Any]] = None,
) -> ST:
    """
    Start serving SaltyRTC Signalling Clients.

    Arguments:
        - `ssl_context`: An `ssl.SSLContext` instance for WSS.
        - `keys`: A sorted sequence of :class:`libnacl.public.SecretKey`
          instances containing permanent private keys of the server.
          The first key will be designated as the primary key.
        - `paths`: A :class:`Paths` instance that maps path names to
          :class:`Path` instances. Can be used to share paths on
          multiple WebSockets. Defaults to an empty paths instance.
        - `host`: The hostname or IP address the server will listen on.
          Defaults to all interfaces.
        - `port`: The port the server will listen on. Defaults to
          `8765`.
        - `loop`: A :class:`asyncio.BaseEventLoop` instance or `None`
          if the default event loop should be used. It is handed to the
          :class:`Server` instance only; it is not forwarded to the
          WebSocket server.
        - `event_callbacks`: An optional dict with keys being an
          :class:`Event` and the value being a list of callback
          coroutines. The callback will be called every time the event
          occurs.
        - `server_class`: An optional :class:`Server` class to create
          an instance from.
        - `ws_kwargs`: Additional keyword arguments passed to
          :func:`websockets.server.serve`. Note that the fields `ssl`,
          `host`, `port`, `subprotocols`, `ping_interval` and
          `select_subprotocol` will be overridden. Any other field
          (including `loop`) is passed through unchanged.

          If the `compression` field is not explicitly set,
          compression will be disabled (since the data to be compressed
          is already encrypted, compression will have little to no
          positive effect).

    Returns the server instance, with its `server` attribute set to the
    started WebSocket server.

    Raises :exc:`ServerKeyError` in case one or more keys have been repeated.
    """
    if loop is None:
        loop = asyncio.get_event_loop()

    # Create paths if not given
    if paths is None:
        paths = Paths()

    # Create server
    if server_class is None:
        server_class = cast('Type[ST]', Server)
    server = server_class(keys, paths, loop=loop)

    # Register event callbacks
    if event_callbacks is not None:
        for event, callbacks in event_callbacks.items():
            for callback in callbacks:
                server.register_event_callback(event, callback)

    # Prepare arguments for the WS server
    # Note: Copy the caller's mapping so it is never mutated.
    if ws_kwargs is None:
        ws_kwargs = {}
    else:
        ws_kwargs = dict(ws_kwargs)
    ws_kwargs['ssl'] = ssl_context
    ws_kwargs['host'] = host
    ws_kwargs['port'] = port
    ws_kwargs.setdefault('compression', None)
    ws_kwargs['ping_interval'] = None  # Disable the keep-alive of the transport library
    ws_kwargs['subprotocols'] = server.subprotocols
    ws_kwargs['select_subprotocol'] = server.protocol_class.select_subprotocol

    # Start WS server
    ws_server = await websockets.serve(server.handler, **ws_kwargs)

    # Set WS server instance
    server.server = ws_server

    # Return server
    return server
192
193
194class ServerProtocol:
195    PATH_LENGTH = KEY_LENGTH * 2  # type: ClassVar[int]
196
197    __slots__ = (
198        '_log',
199        '_loop',
200        '_server',
201        'subprotocol',
202        'path',
203        'client',
204        'handler_task'
205    )
206
207    @classmethod
208    def select_subprotocol(
209            cls,
210            client_subprotocols: Sequence[Subprotocol],
211            server_subprotocols: Sequence[Subprotocol],
212    ) -> Optional[Subprotocol]:
213        # Determine common subprotocols
214        subprotocols = set(client_subprotocols) & set(server_subprotocols)
215        if len(subprotocols) == 0:
216            return None
217
218        # Sort by combined index
219        def _priority(subprotocol: Subprotocol) -> int:
220            return (client_subprotocols.index(subprotocol) +
221                    server_subprotocols.index(subprotocol))
222        return sorted(subprotocols, key=_priority)[0]
223
224    def __init__(
225            self,
226            server: 'Server',
227            subprotocol: SubProtocol,
228            connection: websockets.WebSocketServerProtocol,
229            ws_path: str,
230            loop: Optional[asyncio.AbstractEventLoop] = None,
231    ) -> None:
232        self._log = util.get_logger('server.protocol')
233        self._loop = asyncio.get_event_loop() if loop is None else loop
234
235        # Server instance and subprotocol
236        self._server = server
237        self.subprotocol = subprotocol
238
239        # Path and client instance
240        self.path = None  # type: Optional[Path]
241        self.client = None  # type: Optional[PathClient]
242        self._log.debug('New connection on WS path {}', ws_path)
243
244        # Get path and client instance as early as possible
245        try:
246            path, client = self.get_path_client(connection, ws_path)
247        except PathError as exc:
248            self._log.notice('Closing due to path error: {}', exc)
249
250            async def close_with_protocol_error() -> None:
251                await connection.close(code=CloseCode.protocol_error.value)
252                self._server.notify_disconnected(
253                    None, DisconnectedData(CloseCode.protocol_error.value))
254            handler_coroutine = close_with_protocol_error()
255        else:
256            handler_coroutine = self.handler()
257            client.log.info('Connection established')
258            client.log.debug('Worker started')
259
260            # Store path and client
261            self.path = path
262            self.client = client
263            self._server.register(self)
264
265        # Start handler task
266        log_handler = functools.partial(
267            self._log.exception, 'Unhandled exception in protocol handler:')
268        # noinspection PyTypeChecker
269        self.handler_task = self._loop.create_task(
270            util.log_exception(handler_coroutine, log_handler))
271
272    async def handler(self) -> None:
273        client, path = self.client, self.path
274        assert client is not None
275        assert path is not None
276
277        # Handle client until disconnected or an exception occurred
278        hex_path = PathHex(binascii.hexlify(path.initiator_key).decode('ascii'))
279        close_future = asyncio.Future(loop=self._loop)  # type: asyncio.Future[None]
280        try:
281            await self.handle_client()
282        except Disconnected as exc:
283            client.log.info('Connection closed (code: {})', exc.reason)
284            close_future.set_result(None)
285            close_awaitable = close_future  # type: Awaitable[None]
286            self._server.notify_disconnected(hex_path, DisconnectedData(exc.reason))
287        except PingTimeoutError:
288            client.log.info('Closing because of a ping timeout')
289            close_awaitable = client.close(CloseCode.timeout.value)
290            self._server.notify_disconnected(
291                hex_path, DisconnectedData(CloseCode.timeout.value))
292        except SlotsFullError as exc:
293            client.log.notice('Closing because all path slots are full: {}', exc)
294            close_awaitable = client.close(code=CloseCode.path_full_error.value)
295            self._server.notify_disconnected(
296                hex_path, DisconnectedData(CloseCode.path_full_error.value))
297        except ServerKeyError as exc:
298            client.log.notice('Closing due to server key error: {}', exc)
299            close_awaitable = client.close(code=CloseCode.invalid_key.value)
300            self._server.notify_disconnected(
301                hex_path, DisconnectedData(CloseCode.invalid_key.value))
302        except InternalError as exc:
303            client.log.exception('Closing due to an internal error:', exc)
304            close_awaitable = client.close(code=CloseCode.internal_error.value)
305            self._server.notify_disconnected(
306                hex_path, DisconnectedData(CloseCode.internal_error.value))
307        except SignalingError as exc:
308            client.log.notice('Closing due to protocol error: {}', exc)
309            close_awaitable = client.close(code=CloseCode.protocol_error.value)
310            self._server.notify_disconnected(
311                hex_path, DisconnectedData(CloseCode.protocol_error.value))
312        except Exception as exc:
313            client.log.exception('Closing due to exception:', exc)
314            close_awaitable = client.close(code=CloseCode.internal_error.value)
315            self._server.notify_disconnected(
316                hex_path, DisconnectedData(CloseCode.internal_error.value))
317        else:
318            # Note: This should not ever happen since 'handle_client'
319            #       contains an infinite loop that only stops due to an exception.
320            client.log.error('Client closed without exception')
321            close_future.set_result(None)
322            close_awaitable = close_future
323
324        # Schedule closing of the client
325        # Note: This ensures the client is closed soon even if the job queue is holding
326        #       us up.
327        if not isinstance(close_awaitable, asyncio.Future):
328            log_handler = functools.partial(
329                self._log.exception, 'Unhandled exception in closing procedure:')
330            # noinspection PyTypeChecker
331            close_awaitable = self._loop.create_task(
332                util.log_exception(close_awaitable, log_handler))
333
334        # Wait until all queued jobs have been processed and the job queue runner
335        # returned.
336        #
337        # Note: This ensure that a send-error message (and potentially other messages)
338        #       are enqueued towards other clients before the disconnect message.
339        try:
340            await asyncio.wait_for(
341                client.jobs.join(), _JOB_QUEUE_JOIN_TIMEOUT, loop=self._loop)
342        except asyncio.TimeoutError:
343            client.log.error(
344                'Job queue did not complete within {} seconds', _JOB_QUEUE_JOIN_TIMEOUT)
345        else:
346            client.log.debug('Job queue completed')
347
348        # Send disconnected message if the path is not empty and client was authenticated
349        if path.empty:
350            description = 'Skipping potential disconnected message as the path has ' \
351                          'already been detached'
352            client.log.debug(description)
353        elif client.state != ClientState.authenticated:
354            client.log.debug(
355                'Skipping potential disconnected message due to {} state',
356                client.state.name)
357        else:
358            # Initiator: Send to all responders
359            if client.type == AddressType.initiator:
360                responder_ids = path.get_responder_ids()
361                coroutines = []  # type: List[Coroutine[Any, Any, None]]
362                for responder_id in responder_ids:
363                    responder = path.get_responder(responder_id)
364
365                    # Create message and add send coroutine to job queue of the responder
366                    message = DisconnectedMessage.create(
367                        responder_id, INITIATOR_ADDRESS)
368                    responder.log.debug('Enqueueing disconnected message')
369                    coroutines.append(responder.jobs.enqueue(responder.send(message)))
370                try:
371                    await asyncio.gather(*coroutines, loop=self._loop)
372                except Exception as exc:
373                    description = 'Error while dispatching disconnected messages to ' \
374                                  'responders:'
375                    client.log.exception(description, exc)
376            # Responder: Send to initiator (if present)
377            elif client.type == AddressType.responder:
378                try:
379                    initiator = path.get_initiator()
380                except KeyError:
381                    pass  # No initiator present
382                else:
383                    # Create message and add send coroutine to job queue of the
384                    # initiator
385                    message = DisconnectedMessage.create(
386                        INITIATOR_ADDRESS, ResponderAddress(client.id))
387                    initiator.log.debug('Enqueueing disconnected message')
388                    try:
389                        await initiator.jobs.enqueue(initiator.send(message))
390                    except Exception as exc:
391                        description = 'Error while dispatching disconnected message' \
392                                      'to initiator:'
393                        client.log.exception(description, exc)
394            else:
395                client.log.error('Invalid address type: {}', client.type)
396
397        # Wait for the connection to be closed
398        await close_awaitable
399        client.log.debug('WS connection closed')
400
401        # Remove protocol from server and stop
402        self._server.unregister(self)
403        client.log.debug('Worker stopped')
404
405    def close(self, code: CloseCode) -> None:
406        """
407        Close the underlying connection and stop the protocol.
408
409        Arguments:
410            - `code`: The close code.
411        """
412        # Note: The client will be set as early as possible without any yielding.
413        #       Thus, self.client is either set and can be closed or the connection
414        #       is already closing (see the constructor and 'get_path_client')
415        if self.client is not None:
416            # We need to use 'drop' in order to prevent the server from sending a
417            # 'disconnect' message for each client.
418            try:
419                self._drop_client(self.client, code)
420            except KeyError:
421                # We can safely ignore this since clients will be removed immediately
422                # from the path in case they are being dropped by another client.
423                pass
424            except ValueError:
425                # We can also safely ignore this since a client may have already removed
426                # itself from the path.
427                pass
428
429    def get_path_client(
430            self,
431            connection: websockets.WebSocketServerProtocol,
432            ws_path: str,
433    ) -> Tuple[Path, PathClient]:
434        # Extract public key from path
435        initiator_key_hex = ws_path[1:]
436
437        # Validate key
438        if len(initiator_key_hex) != self.PATH_LENGTH:
439            raise PathError('Invalid path length: {}'.format(len(initiator_key_hex)))
440        try:
441            initiator_key = InitiatorPublicPermanentKey(
442                binascii.unhexlify(initiator_key_hex))
443        except (binascii.Error, ValueError) as exc:
444            raise PathError('Could not unhexlify path') from exc
445
446        # Get path instance
447        path = self._server.paths.get(initiator_key)
448
449        # Create client instance
450        client = PathClient(connection, path.number, initiator_key, loop=self._loop)
451
452        # Attach client to path as 'pending'
453        path.add_pending(client)
454
455        # Return path and client
456        return path, client
457
    async def handle_client(self) -> None:
        """
        Run the handshake and, if it succeeds, the client's long-running tasks.

        On success the receive loop matching the client's address type and the
        keep-alive loop are started, but only if the client is still attached
        to the path. On a handshake failure no tasks are started and the job
        queue and task set are cancelled with that failure as the result.

        This coroutine never returns normally. It awaits
        `client.tasks.await_result()`, cancels the job queue with that result,
        removes the client from its path, cleans the path (if still attached)
        and finally raises the result.

        Raises (whatever the client's tasks or the handshake produced), e.g.:
            SignalingError
            PathError
            Disconnected
            MessageError
            MessageFlowError
            SlotsFullError
            DowngradeError
            ServerKeyError
            InternalError
        Also raises :exc:`ValueError` directly (before any task is started) if
        the client's address type is neither initiator nor responder after a
        successful handshake.
        """
        path, client = self.path, self.client
        assert path is not None
        assert client is not None
        # Coroutines that will be handed to `client.tasks.start` below
        tasks = set()  # type: Set[Coroutine[Any, Any, None]]

        # Do handshake
        client.log.debug('Starting handshake')
        try:
            await self.handshake()
        except Exception as exc:
            client.log.info('Handshake aborted')

            # Encountered an exception during the handshake.
            # Note: We already know the result (the exception), so we can cancel both
            #       job queue and tasks.
            result = Result(exc)
            client.jobs.cancel(result)
            client.tasks.cancel(result)
        else:
            # Check if the client is still connected to the path or has already been
            # dropped.
            #
            # Note: This can happen when the client is being picked up and dropped by
            #       another client while running the handshake. To prevent other race
            #       conditions, we have to add the client instance to the path early
            #       during the handshake.
            is_connected = path.has_client(client)
            if is_connected:
                client.log.info('Handshake completed')
            else:
                client.log.info('Handshake completed but client already dropped')

            # Task: Poll for messages
            # Note: The server is notified about the connected client even if the
            #       client has already been dropped; only the loops depend on
            #       `is_connected`.
            hex_path = PathHex(binascii.hexlify(path.initiator_key).decode('ascii'))
            if client.type == AddressType.initiator:
                self._server.notify_initiator_connected(hex_path)
                if is_connected:
                    client.log.debug('Starting runner for initiator')
                    tasks.add(self.initiator_receive_loop())
            elif client.type == AddressType.responder:
                self._server.notify_responder_connected(hex_path)
                if is_connected:
                    client.log.debug('Starting runner for responder')
                    tasks.add(self.responder_receive_loop())
            else:
                raise ValueError('Invalid address type: {}'.format(client.type))

            # Task: Keep alive
            if is_connected:
                client.log.debug('Starting keep-alive task')
                tasks.add(self.keep_alive_loop())

        # Start the tasks and the job queue runner
        # Note: `client.tasks.cancel` is handed to the job queue runner —
        #       presumably so that a failing job cancels the client's tasks
        #       (TODO confirm against the job queue implementation).
        client.jobs.start(client.tasks.cancel)
        client.tasks.start(tasks)

        # Wait until complete
        # Note: This method ensures us that all tasks have been cancelled
        #       when it returns.
        result = await client.tasks.await_result()

        # Cancel pending jobs and remove client from path
        # Note: Removing the client needs to be done here since the re-raise hands
        #       the task back into the event loop allowing other tasks to get the
        #       client's path instance from the path while it is already effectively
        #       disconnected.
        client.jobs.cancel(result)
        try:
            path.remove_client(client)
        except KeyError:
            # We can safely ignore this since clients will be removed immediately
            # from the path in case they are being dropped by another client.
            pass
        except ValueError:
            # We can also safely ignore this since a client may have already removed
            # itself from the path.
            pass

        # Clean the path (if still attached)
        if path.attached:
            self._server.paths.clean(path)

        # Done! Raise the result
        raise result
554
555    async def handshake(self) -> None:
556        """
557        Disconnected
558        MessageError
559        MessageFlowError
560        SlotsFullError
561        DowngradeError
562        ServerKeyError
563        """
564        client = self.client
565        assert client is not None
566
567        # Send server-hello
568        server_hello = ServerHelloMessage.create(
569            ServerPublicPermanentKey(client.server_key.pk))
570        client.log.debug('Sending server-hello')
571        await client.send(server_hello)
572
573        # Receive client-hello or client-auth
574        client.log.debug('Waiting for client-hello or client-auth')
575        client_auth = await client.receive()
576        if isinstance(client_auth, ClientAuthMessage):
577            client.log.debug('Received client-auth')
578            # Client is the initiator
579            client.type = AddressType.initiator
580            await self.handshake_initiator(client_auth)
581        elif isinstance(client_auth, ClientHelloMessage):
582            client.log.debug('Received client-hello')
583            # Client is a responder
584            client.type = AddressType.responder
585            await self.handshake_responder(client_auth)
586        else:
587            error = "Expected 'client-hello' or 'client-auth', got '{}'"
588            raise MessageFlowError(error.format(client_auth.type))
589
590    async def handshake_initiator(self, client_auth: ClientAuthMessage) -> None:
591        """
592        Disconnected
593        MessageError
594        MessageFlowError
595        DowngradeError
596        ServerKeyError
597        """
598        path, initiator = self.path, self.client
599        assert path is not None
600        assert initiator is not None
601
602        # Handle client-auth
603        self._handle_client_auth(client_auth)
604
605        # Authenticated
606        previous_initiator = path.set_initiator(initiator)
607        if previous_initiator is not None:
608            # Drop previous initiator using its job queue
609            path.log.debug('Dropping previous initiator {}', previous_initiator)
610            previous_initiator.log.debug('Dropping (another initiator connected)')
611            self._drop_client(previous_initiator, CloseCode.drop_by_initiator)
612
613        # Send new-initiator message if any responder is present
614        responder_ids = path.get_responder_ids()
615        coroutines = []  # type: List[Coroutine[Any, Any, None]]
616        for responder_id in responder_ids:
617            responder = path.get_responder(responder_id)
618
619            # Create message and add send coroutine to job queue of the responder
620            new_initiator = NewInitiatorMessage.create(responder_id)
621            responder.log.debug('Enqueueing new-initiator message')
622            coroutines.append(responder.jobs.enqueue(responder.send(new_initiator)))
623        await asyncio.gather(*coroutines, loop=self._loop)
624
625        # Send server-auth
626        responder_ids = list(path.get_responder_ids())
627        server_auth = ServerAuthMessage.create(
628            INITIATOR_ADDRESS, initiator.cookie_in,
629            sign_keys=len(self._server.keys) > 0, responder_ids=responder_ids)
630        initiator.log.debug('Sending server-auth including responder ids')
631        await initiator.send(server_auth)
632
633    async def handshake_responder(self, client_hello: ClientHelloMessage) -> None:
634        """
635        Disconnected
636        MessageError
637        MessageFlowError
638        SlotsFullError
639        DowngradeError
640        ServerKeyError
641        """
642        path, responder = self.path, self.client
643        assert path is not None
644        assert responder is not None
645
646        # Set key on client
647        responder.set_client_key(
648            ResponderPublicSessionKey(client_hello.client_public_key))
649
650        # Receive client-auth
651        client_auth = await responder.receive()
652        if not isinstance(client_auth, ClientAuthMessage):
653            error = "Expected 'client-auth', got '{}'"
654            raise MessageFlowError(error.format(client_auth.type))
655
656        # Handle client-auth
657        self._handle_client_auth(client_auth)
658
659        # Authenticated
660        id_ = path.add_responder(responder)
661
662        # Send new-responder message if initiator is present
663        initiator = None  # type: Optional[PathClient]
664        try:
665            initiator = path.get_initiator()
666        except KeyError:
667            pass
668        else:
669            # Create message and add send coroutine to job queue of the initiator
670            new_responder = NewResponderMessage.create(id_)
671            initiator.log.debug('Enqueueing new-responder message')
672            await initiator.jobs.enqueue(initiator.send(new_responder))
673
674        # Send server-auth
675        server_auth = ServerAuthMessage.create(
676            ResponderAddress(responder.id), responder.cookie_in,
677            sign_keys=len(self._server.keys) > 0,
678            initiator_connected=initiator is not None)
679        responder.log.debug('Sending server-auth without responder ids')
680        await responder.send(server_auth)
681
682    async def initiator_receive_loop(self) -> NoReturn:
683        path, initiator = self.path, self.client
684        assert path is not None
685        assert initiator is not None
686        while True:
687            # Receive relay message or drop-responder
688            message = await initiator.receive()
689
690            # Relay
691            if isinstance(message, RelayMessage):
692                # Lookup responder
693                responder = None  # type: Optional[PathClient]
694                try:
695                    responder_id = ResponderAddress(message.destination)
696                    responder = path.get_responder(responder_id)
697                except KeyError:
698                    pass
699                # Send to responder
700                await self.relay_message(
701                    responder, ClientAddress(message.destination), message)
702            # Drop-responder
703            elif isinstance(message, DropResponderMessage):
704                # Lookup responder
705                try:
706                    responder = path.get_responder(message.responder_id)
707                except KeyError:
708                    log_message = 'Responder {} already dropped, nothing to do'
709                    path.log.debug(log_message, message.responder_id)
710                else:
711                    # Drop responder using its job queue
712                    path.log.debug(
713                        'Dropping responder {}, reason: {}', responder, message.reason)
714                    responder.log.debug(
715                        'Dropping (requested by initiator), reason: {}', message.reason)
716                    self._drop_client(responder, CloseCode(message.reason))
717            else:
718                error = "Expected relay message or 'drop-responder', got '{}'"
719                raise MessageFlowError(error.format(message.type))
720
721    async def responder_receive_loop(self) -> NoReturn:
722        path, responder = self.path, self.client
723        assert path is not None
724        assert responder is not None
725        while True:
726            # Receive relay message
727            message = await responder.receive()
728
729            # Relay
730            if isinstance(message, RelayMessage):
731                # Lookup initiator
732                initiator = None  # type: Optional[PathClient]
733                try:
734                    initiator = path.get_initiator()
735                except KeyError:
736                    pass
737                # Send to initiator
738                await self.relay_message(initiator, INITIATOR_ADDRESS, message)
739            else:
740                error = "Expected relay message, got '{}'"
741                raise MessageFlowError(error.format(message.type))
742
743    async def relay_message(
744            self,
745            destination: Optional[PathClient],
746            destination_id: ClientAddress,
747            message: RelayMessage,
748    ) -> None:
749        source = self.client
750        assert source is not None
751
752        # Prepare message
753        source.log.debug('Packing relay message')
754        message_id = MessageId(message.pack(source)[COOKIE_LENGTH:NONCE_LENGTH])
755
756        async def send_error_message() -> None:
757            assert source is not None
758            # Create message and add send coroutine to job queue of the source
759            error = SendErrorMessage.create(ClientAddress(source.id), message_id)
760            source.log.info('Relaying failed, enqueuing send-error')
761            await source.jobs.enqueue(source.send(error))
762
763        # Destination not connected? Send 'send-error' to source
764        if destination is None:
765            error_message = ('Cannot relay message, no connection for '
766                             'destination id 0x{:02x}')
767            source.log.info(error_message, destination_id)
768            await send_error_message()
769            return
770
771        # Add send task to job queue of the destination
772        task = self._loop.create_task(destination.send(message))
773        destination.log.debug('Enqueueing relayed message from 0x{:02x}', source.id)
774        await destination.jobs.enqueue(task)
775
776        # noinspection PyBroadException
777        try:
778            # Wait for send task to complete
779            await asyncio.wait_for(task, RELAY_TIMEOUT, loop=self._loop)
780        except asyncio.TimeoutError:
781            # Timed out, send 'send-error' to source
782            log_message = 'Sending relayed message to 0x{:02x} timed out'
783            source.log.info(log_message, destination_id)
784            await send_error_message()
785        except Exception as exc:
786            # Handle cancellation of the client
787            if isinstance(exc, asyncio.CancelledError) and source.tasks.have_result:
788                raise
789
790            # An exception has been triggered while sending the message.
791            # Note: We don't care about the actual exception as the job
792            #       queue runner will also trigger that exception on the
793            #       destination client's handler who will log what happened.
794            log_message = 'Sending relayed message failed, receiver 0x{:02x} is gone'
795            source.log.info(log_message, destination_id)
796            await send_error_message()
797        else:
798            source.log.debug('Sending relayed message to 0x{:02x} successful',
799                             destination.id)
800
801    async def keep_alive_loop(self) -> NoReturn:
802        """
803        Disconnected
804        PingTimeoutError
805        """
806        client = self.client
807        assert client is not None
808        while True:
809            # Wait
810            # noinspection PyTypeChecker
811            await asyncio.sleep(client.keep_alive_interval, loop=self._loop)
812
813            # Send ping and wait for pong
814            client.log.debug('Ping')
815            pong_future = await client.ping()
816            try:
817                await asyncio.wait_for(
818                    client.wait_pong(pong_future), client.keep_alive_timeout,
819                    loop=self._loop)
820            except asyncio.TimeoutError:
821                client.log.debug('Ping timed out')
822                raise PingTimeoutError(str(client))
823            else:
824                client.log.debug('Pong')
825                client.keep_alive_pings += 1
826
827    def _handle_client_auth(self, client_auth: ClientAuthMessage) -> None:
828        """
829        MessageError
830        DowngradeError
831        ServerKeyError
832        """
833        client = self.client
834        assert client is not None
835
836        # Validate cookie and ensure no sub-protocol downgrade took place
837        self._validate_cookie(client_auth.server_cookie, client.cookie_out)
838        self._validate_subprotocol(client_auth.subprotocols)
839
840        # Set the keep alive interval (if any)
841        if client_auth.ping_interval is not None:
842            client.log.debug(
843                'Setting keep-alive interval to {}', client_auth.ping_interval)
844            client.keep_alive_interval = client_auth.ping_interval
845
846        # Set the public permanent key the client wants to use (or fallback to primary)
847        server_keys_count = len(self._server.keys)
848        if client_auth.server_key is not None:
849            # No permanent key pair?
850            if server_keys_count == 0:
851                raise ServerKeyError('Server does not have a permanent public key')
852
853            # Find the key instance
854            server_key = self._server.keys.get(client_auth.server_key)
855            if server_key is None:
856                raise ServerKeyError(
857                    'Server does not have the requested permanent public key')
858
859            # Set the key instance on the client
860            client.server_permanent_key = server_key
861        elif server_keys_count > 0:
862            # Use primary permanent key
863            client.server_permanent_key = next(iter(self._server.keys.values()))
864
865    def _validate_cookie(
866            self,
867            expected_cookie: ServerCookie,
868            actual_cookie: ServerCookie,
869    ) -> None:
870        """
871        MessageError
872        """
873        client = self.client
874        assert client is not None
875        client.log.debug('Validating cookie')
876        if not util.consteq(expected_cookie, actual_cookie):
877            raise MessageError('Cookies do not match')
878
879    def _validate_subprotocol(
880            self,
881            client_subprotocols: ListOrTuple[ChosenSubProtocol],
882    ) -> None:
883        """
884        MessageError
885        DowngradeError
886        """
887        client = self.client
888        assert client is not None
889        client.log.debug(
890            'Checking for subprotocol downgrade, client: {}, server: {}',
891            client_subprotocols, self._server.subprotocols)
892        chosen = self.select_subprotocol(
893            cast(Sequence[Subprotocol], client_subprotocols),
894            cast(Sequence[Subprotocol], self._server.subprotocols))
895        if chosen != self.subprotocol.value:
896            raise DowngradeError('Subprotocol downgrade detected')
897
898    def _drop_client(self, client: PathClient, code: CloseCode) -> None:
899        """
900        Mark the client as closed, schedule the closing procedure on
901        the client's job queue and remove it from the path.
902
903        .. important:: This should only be called by clients dropping
904                       another client or when the server is closing.
905
906        Arguments:
907            - `client`: The client to be dropped.
908            - `close`: The close code.
909
910        Raises:
911            - :exc:`KeyError` in case the client is not attached to the
912              path.
913            - :exc:`ValueError` in case the path has been detached.
914        """
915        # Drop the client
916        client.drop(code)
917
918        # Remove the client from the path
919        path = self.path
920        assert path is not None
921        path.remove_client(client)
922
923
class Paths:
    """
    Registry of :class:`Path` instances, keyed by the initiator's public
    permanent key.
    """
    __slots__ = ('_log', 'number', 'paths')

    def __init__(self) -> None:
        self._log = util.get_logger('paths')
        # Running counter, used as the number of the next created path
        self.number = 0
        self.paths = {}  # type: Dict[InitiatorPublicPermanentKey, Path]

    def get(self, initiator_key: InitiatorPublicPermanentKey) -> Path:
        """
        Return the path for `initiator_key`. If no path is registered for
        that key, a new attached path is created, registered and returned.
        """
        path = self.paths.get(initiator_key)
        if path is None:
            self.number += 1
            path = Path(initiator_key, self.number, attached=True)
            self.paths[initiator_key] = path
            self._log.debug('Created new path: {}', self.number)
        return path

    def clean(self, path: Path) -> None:
        """
        Detach, unregister and clear `path` if it is empty. Non-empty
        paths are left untouched.
        """
        if not path.empty:
            return

        path.attached = False
        # Only remove the mapping if it still refers to *this* path. After
        # a path has been removed, `get` may have created a new path for
        # the same initiator key; deleting that one would orphan its
        # clients.
        if self.paths.get(path.initiator_key) is path:
            del self.paths[path.initiator_key]
            self._log.debug('Removed empty path: {}', path.number)
        else:
            self._log.error('Path {} has already been removed', path.number)
        path.clear()
949
950
951class Server:
952    subprotocols = [
953        SubProtocol.saltyrtc_v1.value
954    ]  # type: ClassVar[Sequence[SubProtocol]]
955
956    def __init__(
957            self,
958            keys: Optional[Sequence[ServerSecretPermanentKey]],
959            paths: Paths,
960            loop: Optional[asyncio.AbstractEventLoop] = None,
961    ) -> None:
962        self._log = util.get_logger('server')
963        self._loop = asyncio.get_event_loop() if loop is None else loop
964
965        # Protocol class
966        self.protocol_class = ServerProtocol  # type: Type[ServerProtocol]
967
968        # WebSocket server instance
969        self._server = None  # type: Optional[websockets.server.WebSocketServer]
970
971        # Validate & store keys
972        if keys is None:
973            keys = []
974        if len(keys) != len({key.pk for key in keys}):
975            raise ServerKeyError('Repeated permanent keys')
976        self.keys = OrderedDict(
977            ((ServerPublicPermanentKey(key.pk), key) for key in keys))  # type: Keys
978
979        # Store paths
980        self.paths = paths
981
982        # Store server protocols and closing task
983        self.protocols = set()  # type: Set[ServerProtocol]
984        self._close_task = None  # type: Optional[asyncio.Task[None]]
985
986        # Event Registry
987        self._events = EventRegistry()
988
989    @property
990    def server(self) -> websockets.server.WebSocketServer:
991        assert self._server is not None
992        return self._server
993
994    @server.setter
995    def server(self, server: websockets.server.WebSocketServer) -> None:
996        self._server = server
997        self._log.debug('Server instance: {}', server)
998
999    async def handler(
1000            self,
1001            connection: websockets.WebSocketServerProtocol,
1002            ws_path: str,
1003    ) -> None:
1004        # Closing? Drop immediately
1005        if self._close_task is not None:
1006            await connection.close(CloseCode.going_away.value)
1007            return
1008
1009        # Convert sub-protocol
1010        subprotocol = None  # type: Optional[SubProtocol]
1011        try:
1012            subprotocol = SubProtocol(connection.subprotocol)
1013        except ValueError:
1014            pass
1015
1016        # Determine ServerProtocol instance by selected sub-protocol
1017        if subprotocol != SubProtocol.saltyrtc_v1:
1018            self._log.notice('Could not negotiate a sub-protocol, dropping client')
1019            # We need to close the connection manually as the client may choose
1020            # to ignore
1021            await connection.close(code=CloseCode.subprotocol_error.value)
1022            self.notify_disconnected(
1023                None, DisconnectedData(CloseCode.subprotocol_error.value))
1024        else:
1025            assert subprotocol is not None
1026            protocol = self.protocol_class(
1027                self, subprotocol, connection, ws_path, loop=self._loop)
1028            await protocol.handler_task
1029
1030    def register(self, protocol: ServerProtocol) -> None:
1031        self.protocols.add(protocol)
1032        self._log.debug('Protocol registered: {}', protocol)
1033
1034    def unregister(self, protocol: ServerProtocol) -> None:
1035        self.protocols.remove(protocol)
1036        self._log.debug('Protocol unregistered: {}', protocol)
1037
1038    def register_event_callback(self, event: Event, callback: EventCallback) -> None:
1039        """
1040        Register a new event callback.
1041        """
1042        self._events.register(event, callback)
1043
1044    def notify_initiator_connected(self, path: PathHex) -> None:
1045        self._raise_event(Event.initiator_connected, path, None)
1046
1047    def notify_responder_connected(self, path: PathHex) -> None:
1048        self._raise_event(Event.responder_connected, path, None)
1049
1050    def notify_disconnected(
1051            self,
1052            path: Optional[PathHex],
1053            data: DisconnectedData,
1054    ) -> None:
1055        self._raise_event(Event.disconnected, path, data)
1056
1057    def _raise_event(
1058            self,
1059            event: Event,
1060            path: Optional[PathHex],
1061            data: EventData,
1062    ) -> None:
1063        """
1064        Raise an event and invoke all registered event callbacks.
1065
1066        Arguments:
1067            - `event`: Event to be raised.
1068            - `path`: Associated path in hexadecimal representation or
1069              `None` if not available.
1070            - `data`: Additional data for the event as explained for
1071              :class:`EventRegistry`.
1072        """
1073        for callback in self._events.get_callbacks(event):
1074            coroutine = callback(event, path, data)
1075            log_handler = functools.partial(
1076                self._log.exception, 'Unhandled exception in event handler:')
1077            # noinspection PyTypeChecker
1078            self._loop.create_task(util.log_exception(coroutine, log_handler))
1079
1080    def close(self) -> None:
1081        """
1082        Close open connections and the server.
1083        """
1084        if self._close_task is None:
1085            log_handler = functools.partial(
1086                self._log.exception, 'Exception while closing:')
1087            # noinspection PyTypeChecker
1088            self._close_task = self._loop.create_task(
1089                util.log_exception(self._close_after_all_protocols_closed(), log_handler))
1090
1091    async def wait_closed(self) -> None:
1092        """
1093        Wait until all connections and the server itself has been
1094        closed.
1095        """
1096        await self.server.wait_closed()
1097
1098    async def _close_after_all_protocols_closed(
1099            self,
1100            timeout: Optional[float] = None,
1101    ) -> None:
1102        # Schedule closing all protocols
1103        self._log.info('Closing protocols')
1104        if len(self.protocols) > 0:
1105            async def _close_and_wait() -> None:
1106                # Wait until all connections have been scheduled to be closed
1107                for protocol in self.protocols:
1108                    protocol.close(CloseCode.going_away)
1109
1110                # Wait until all protocols have returned
1111                handler_tasks = [protocol.handler_task for protocol in self.protocols]
1112                await asyncio.gather(*handler_tasks, loop=self._loop)
1113
1114            await asyncio.wait_for(_close_and_wait(), timeout, loop=self._loop)
1115
1116        # Now we can close the server
1117        self._log.info('Closing server')
1118        self.server.close()
1119