1"""
2:mod:`websockets.server` defines the WebSocket server APIs.
3
4"""
5
6import asyncio
7import collections.abc
8import email.utils
9import functools
10import http
11import logging
12import socket
13import sys
14import warnings
15from types import TracebackType
16from typing import (
17    Any,
18    Awaitable,
19    Callable,
20    Generator,
21    List,
22    Optional,
23    Sequence,
24    Set,
25    Tuple,
26    Type,
27    Union,
28    cast,
29)
30
31from .exceptions import (
32    AbortHandshake,
33    InvalidHandshake,
34    InvalidHeader,
35    InvalidMessage,
36    InvalidOrigin,
37    InvalidUpgrade,
38    NegotiationError,
39)
40from .extensions.base import Extension, ServerExtensionFactory
41from .extensions.permessage_deflate import ServerPerMessageDeflateFactory
42from .handshake import build_response, check_request
43from .headers import build_extension, parse_extension, parse_subprotocol
44from .http import USER_AGENT, Headers, HeadersLike, MultipleValuesError, read_request
45from .protocol import WebSocketCommonProtocol
46from .typing import ExtensionHeader, Origin, Subprotocol
47
48
# Public API of this module.
__all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"]

logger = logging.getLogger(__name__)


# Type of the ``extra_headers`` option: either headers (any HeadersLike value)
# or a callable that receives the request path and the request headers and
# returns headers.
HeadersLikeOrCallable = Union[HeadersLike, Callable[[str, Headers], HeadersLike]]

# Type of the early HTTP response that ``process_request`` may return to abort
# the handshake: a ``(status, headers, body)`` 3-tuple.
HTTPResponse = Tuple[http.HTTPStatus, HeadersLike, bytes]
57
58
class WebSocketServerProtocol(WebSocketCommonProtocol):
    """
    :class:`~asyncio.Protocol` subclass implementing a WebSocket server.

    This class inherits most of its methods from
    :class:`~websockets.protocol.WebSocketCommonProtocol`.

    For the sake of simplicity, it doesn't rely on a full HTTP implementation.
    Its support for HTTP responses is very limited.

    """

    is_client = False
    side = "server"

    def __init__(
        self,
        ws_handler: Callable[["WebSocketServerProtocol", str], Awaitable[Any]],
        ws_server: "WebSocketServer",
        *,
        origins: Optional[Sequence[Optional[Origin]]] = None,
        extensions: Optional[Sequence[ServerExtensionFactory]] = None,
        subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLikeOrCallable] = None,
        process_request: Optional[
            Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]]
        ] = None,
        # The callback may decline to pick a subprotocol by returning None,
        # exactly like the select_subprotocol() method it replaces.
        select_subprotocol: Optional[
            Callable[
                [Sequence[Subprotocol], Sequence[Subprotocol]], Optional[Subprotocol]
            ]
        ] = None,
        **kwargs: Any,
    ) -> None:
        # For backwards compatibility with 6.0 or earlier.
        if origins is not None and "" in origins:
            warnings.warn("use None instead of '' in origins", DeprecationWarning)
            origins = [None if origin == "" else origin for origin in origins]
        self.ws_handler = ws_handler
        self.ws_server = ws_server
        self.origins = origins
        self.available_extensions = extensions
        self.available_subprotocols = subprotocols
        self.extra_headers = extra_headers
        self._process_request = process_request
        self._select_subprotocol = select_subprotocol
        super().__init__(**kwargs)

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """
        Register connection and initialize a task to handle it.

        """
        super().connection_made(transport)
        # Register the connection with the server before creating the handler
        # task. Registering at the beginning of the handler coroutine would
        # create a race condition between the creation of the task, which
        # schedules its execution, and the moment the handler starts running.
        self.ws_server.register(self)
        self.handler_task = self.loop.create_task(self.handler())

    async def handler(self) -> None:
        """
        Handle the lifecycle of a WebSocket connection.

        Since this method doesn't have a caller able to handle exceptions, it
        attempts to log relevant ones and guarantees that the TCP connection is
        closed before exiting.

        """
        try:

            try:
                path = await self.handshake(
                    origins=self.origins,
                    available_extensions=self.available_extensions,
                    available_subprotocols=self.available_subprotocols,
                    extra_headers=self.extra_headers,
                )
            except ConnectionError:
                logger.debug("Connection error in opening handshake", exc_info=True)
                raise
            except Exception as exc:
                # Map the failure to an HTTP error response sent to the client.
                if isinstance(exc, AbortHandshake):
                    status, headers, body = exc.status, exc.headers, exc.body
                elif isinstance(exc, InvalidOrigin):
                    logger.debug("Invalid origin", exc_info=True)
                    status, headers, body = (
                        http.HTTPStatus.FORBIDDEN,
                        Headers(),
                        f"Failed to open a WebSocket connection: {exc}.\n".encode(),
                    )
                elif isinstance(exc, InvalidUpgrade):
                    logger.debug("Invalid upgrade", exc_info=True)
                    status, headers, body = (
                        http.HTTPStatus.UPGRADE_REQUIRED,
                        Headers([("Upgrade", "websocket")]),
                        (
                            f"Failed to open a WebSocket connection: {exc}.\n"
                            f"\n"
                            f"You cannot access a WebSocket server directly "
                            f"with a browser. You need a WebSocket client.\n"
                        ).encode(),
                    )
                elif isinstance(exc, InvalidHandshake):
                    logger.debug("Invalid handshake", exc_info=True)
                    status, headers, body = (
                        http.HTTPStatus.BAD_REQUEST,
                        Headers(),
                        f"Failed to open a WebSocket connection: {exc}.\n".encode(),
                    )
                else:
                    logger.warning("Error in opening handshake", exc_info=True)
                    status, headers, body = (
                        http.HTTPStatus.INTERNAL_SERVER_ERROR,
                        Headers(),
                        (
                            b"Failed to open a WebSocket connection.\n"
                            b"See server log for more information.\n"
                        ),
                    )

                headers.setdefault("Date", email.utils.formatdate(usegmt=True))
                headers.setdefault("Server", USER_AGENT)
                headers.setdefault("Content-Length", str(len(body)))
                headers.setdefault("Content-Type", "text/plain")
                headers.setdefault("Connection", "close")

                self.write_http_response(status, headers, body)
                self.fail_connection()
                await self.wait_closed()
                return

            try:
                await self.ws_handler(self, path)
            except Exception:
                logger.error("Error in connection handler", exc_info=True)
                if not self.closed:
                    self.fail_connection(1011)
                raise

            try:
                await self.close()
            except ConnectionError:
                logger.debug("Connection error in closing handshake", exc_info=True)
                raise
            except Exception:
                logger.warning("Error in closing handshake", exc_info=True)
                raise

        except Exception:
            # Last-ditch attempt to avoid leaking connections on errors.
            try:
                self.transport.close()
            except Exception:  # pragma: no cover
                pass

        finally:
            # Unregister the connection with the server when the handler task
            # terminates. Registration is tied to the lifecycle of the handler
            # task because the server waits for tasks attached to registered
            # connections before terminating.
            self.ws_server.unregister(self)

    async def read_http_request(self) -> Tuple[str, Headers]:
        """
        Read request line and headers from the HTTP request.

        If the request contains a body, it may be read from ``self.reader``
        after this coroutine returns.

        :raises ~websockets.exceptions.InvalidMessage: if the HTTP message is
            malformed or isn't an HTTP/1.1 GET request

        """
        try:
            path, headers = await read_request(self.reader)
        except Exception as exc:
            raise InvalidMessage("did not receive a valid HTTP request") from exc

        logger.debug("%s < GET %s HTTP/1.1", self.side, path)
        logger.debug("%s < %r", self.side, headers)

        self.path = path
        self.request_headers = headers

        return path, headers

    def write_http_response(
        self, status: http.HTTPStatus, headers: Headers, body: Optional[bytes] = None
    ) -> None:
        """
        Write status line and headers to the HTTP response.

        This method is also able to write a response body. It writes to the
        transport directly; it isn't a coroutine and doesn't wait for the data
        to be flushed.

        """
        self.response_headers = headers

        logger.debug("%s > HTTP/1.1 %d %s", self.side, status.value, status.phrase)
        logger.debug("%s > %r", self.side, headers)

        # Since the status line and headers only contain ASCII characters,
        # we can keep this simple.
        response = f"HTTP/1.1 {status.value} {status.phrase}\r\n"
        response += str(headers)

        self.transport.write(response.encode())

        if body is not None:
            logger.debug("%s > body (%d bytes)", self.side, len(body))
            self.transport.write(body)

    async def process_request(
        self, path: str, request_headers: Headers
    ) -> Optional[HTTPResponse]:
        """
        Intercept the HTTP request and return an HTTP response if appropriate.

        If ``process_request`` returns ``None``, the WebSocket handshake
        continues. If it returns 3-uple containing a status code, response
        headers and a response body, that HTTP response is sent and the
        connection is closed. In that case:

        * The HTTP status must be a :class:`~http.HTTPStatus`.
        * HTTP headers must be a :class:`~websockets.http.Headers` instance, a
          :class:`~collections.abc.Mapping`, or an iterable of ``(name,
          value)`` pairs.
        * The HTTP response body must be :class:`bytes`. It may be empty.

        This coroutine may be overridden in a :class:`WebSocketServerProtocol`
        subclass, for example:

        * to return a HTTP 200 OK response on a given path; then a load
          balancer can use this path for a health check;
        * to authenticate the request and return a HTTP 401 Unauthorized or a
          HTTP 403 Forbidden when authentication fails.

        Instead of subclassing, it is possible to override this method by
        passing a ``process_request`` argument to the :func:`serve` function
        or the :class:`WebSocketServerProtocol` constructor. This is
        equivalent, except ``process_request`` won't have access to the
        protocol instance, so it can't store information for later use.

        ``process_request`` is expected to complete quickly. If it may run for
        a long time, then it should await :meth:`wait_closed` and exit if
        :meth:`wait_closed` completes, or else it could prevent the server
        from shutting down.

        :param path: request path, including optional query string
        :param request_headers: request headers

        """
        if self._process_request is not None:
            response = self._process_request(path, request_headers)
            if isinstance(response, Awaitable):
                return await response
            else:
                # For backwards compatibility with 7.0.
                warnings.warn(
                    "declare process_request as a coroutine", DeprecationWarning
                )
                return response  # type: ignore
        return None

    @staticmethod
    def process_origin(
        headers: Headers, origins: Optional[Sequence[Optional[Origin]]] = None
    ) -> Optional[Origin]:
        """
        Handle the Origin HTTP request header.

        :param headers: request headers
        :param origins: optional list of acceptable origins
        :raises ~websockets.exceptions.InvalidOrigin: if the origin isn't
            acceptable

        """
        # "The user agent MUST NOT include more than one Origin header field"
        # per https://tools.ietf.org/html/rfc6454#section-7.3.
        try:
            origin = cast(Origin, headers.get("Origin"))
        except MultipleValuesError:
            raise InvalidHeader("Origin", "more than one Origin header found")
        if origins is not None:
            if origin not in origins:
                raise InvalidOrigin(origin)
        return origin

    @staticmethod
    def process_extensions(
        headers: Headers,
        available_extensions: Optional[Sequence[ServerExtensionFactory]],
    ) -> Tuple[Optional[str], List[Extension]]:
        """
        Handle the Sec-WebSocket-Extensions HTTP request header.

        Accept or reject each extension proposed in the client request.
        Negotiate parameters for accepted extensions.

        Return the Sec-WebSocket-Extensions HTTP response header and the list
        of accepted extensions.

        :rfc:`6455` leaves the rules up to the specification of each
        extension.

        To provide this level of flexibility, for each extension proposed by
        the client, we check for a match with each extension available in the
        server configuration. If no match is found, the extension is ignored.

        If several variants of the same extension are proposed by the client,
        it may be accepted several times, which won't make sense in general.
        Extensions must implement their own requirements. For this purpose,
        the list of previously accepted extensions is provided.

        This process doesn't allow the server to reorder extensions. It can
        only select a subset of the extensions proposed by the client.

        Other requirements, for example related to mandatory extensions or the
        order of extensions, may be implemented by overriding this method.

        :param headers: request headers
        :param available_extensions: optional list of supported extensions
        :raises ~websockets.exceptions.InvalidHandshake: to abort the
            handshake with an HTTP 400 error code

        """
        response_header_value: Optional[str] = None

        extension_headers: List[ExtensionHeader] = []
        accepted_extensions: List[Extension] = []

        header_values = headers.get_all("Sec-WebSocket-Extensions")

        if header_values and available_extensions:

            # Flatten the extensions of all header lines, keeping their order.
            # A comprehension avoids the quadratic cost of sum(lists, []).
            parsed_header_values: List[ExtensionHeader] = [
                extension_header
                for header_value in header_values
                for extension_header in parse_extension(header_value)
            ]

            for name, request_params in parsed_header_values:

                for ext_factory in available_extensions:

                    # Skip non-matching extensions based on their name.
                    if ext_factory.name != name:
                        continue

                    # Skip non-matching extensions based on their params.
                    try:
                        response_params, extension = ext_factory.process_request_params(
                            request_params, accepted_extensions
                        )
                    except NegotiationError:
                        continue

                    # Add matching extension to the final list.
                    extension_headers.append((name, response_params))
                    accepted_extensions.append(extension)

                    # Break out of the loop once we have a match.
                    break

                # If we didn't break from the loop, no extension in our list
                # matched what the client sent. The extension is declined.

        # Serialize extension header.
        if extension_headers:
            response_header_value = build_extension(extension_headers)

        return response_header_value, accepted_extensions

    # Not @staticmethod because it calls self.select_subprotocol()
    def process_subprotocol(
        self, headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]]
    ) -> Optional[Subprotocol]:
        """
        Handle the Sec-WebSocket-Protocol HTTP request header.

        Return Sec-WebSocket-Protocol HTTP response header, which is the same
        as the selected subprotocol.

        :param headers: request headers
        :param available_subprotocols: optional list of supported subprotocols
        :raises ~websockets.exceptions.InvalidHandshake: to abort the
            handshake with an HTTP 400 error code

        """
        subprotocol: Optional[Subprotocol] = None

        header_values = headers.get_all("Sec-WebSocket-Protocol")

        if header_values and available_subprotocols:

            # Flatten the subprotocols of all header lines, keeping their
            # order. A comprehension avoids the quadratic cost of sum(lists, []).
            parsed_header_values: List[Subprotocol] = [
                parsed
                for header_value in header_values
                for parsed in parse_subprotocol(header_value)
            ]

            subprotocol = self.select_subprotocol(
                parsed_header_values, available_subprotocols
            )

        return subprotocol

    def select_subprotocol(
        self,
        client_subprotocols: Sequence[Subprotocol],
        server_subprotocols: Sequence[Subprotocol],
    ) -> Optional[Subprotocol]:
        """
        Pick a subprotocol among those offered by the client.

        If several subprotocols are supported by the client and the server,
        the default implementation selects the preferred subprotocols by
        giving equal value to the priorities of the client and the server.
        When several subprotocols are equally preferred, the one listed first
        by the client wins, so the choice is deterministic.

        If no subprotocol is supported by the client and the server, it
        proceeds without a subprotocol.

        This is unlikely to be the most useful implementation in practice, as
        many servers providing a subprotocol will require that the client uses
        that subprotocol. Such rules can be implemented in a subclass.

        Instead of subclassing, it is possible to override this method by
        passing a ``select_subprotocol`` argument to the :func:`serve`
        function or the :class:`WebSocketServerProtocol` constructor.

        :param client_subprotocols: list of subprotocols offered by the client
        :param server_subprotocols: list of subprotocols available on the server

        """
        if self._select_subprotocol is not None:
            return self._select_subprotocol(client_subprotocols, server_subprotocols)

        # Keep the client's order rather than building a set: iterating a set
        # of strings depends on hash randomization, which made the winner of a
        # tie vary from one process to the next.
        candidates = [
            subprotocol
            for subprotocol in client_subprotocols
            if subprotocol in server_subprotocols
        ]
        if not candidates:
            return None

        def priority(subprotocol: Subprotocol) -> int:
            # Lower is better: sum of the positions in both preference lists.
            client_rank = client_subprotocols.index(subprotocol)
            server_rank = server_subprotocols.index(subprotocol)
            return client_rank + server_rank

        # min() returns the first candidate with the lowest priority, i.e. the
        # one the client listed first among equally preferred subprotocols.
        return min(candidates, key=priority)

    async def handshake(
        self,
        origins: Optional[Sequence[Optional[Origin]]] = None,
        available_extensions: Optional[Sequence[ServerExtensionFactory]] = None,
        available_subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLikeOrCallable] = None,
    ) -> str:
        """
        Perform the server side of the opening handshake.

        Return the path of the URI of the request.

        :param origins: list of acceptable values of the Origin HTTP header;
            include ``None`` if the lack of an origin is acceptable
        :param available_extensions: list of supported extensions in the order
            in which they should be used
        :param available_subprotocols: list of supported subprotocols in order
            of decreasing preference
        :param extra_headers: sets additional HTTP response headers when the
            handshake succeeds; it can be a :class:`~websockets.http.Headers`
            instance, a :class:`~collections.abc.Mapping`, an iterable of
            ``(name, value)`` pairs, or a callable taking the request path and
            headers in arguments and returning one of the above.
        :raises ~websockets.exceptions.InvalidHandshake: if the handshake
            fails

        """
        path, request_headers = await self.read_http_request()

        # Hook for customizing request handling, for example checking
        # authentication or treating some paths as plain HTTP endpoints.
        early_response_awaitable = self.process_request(path, request_headers)
        if isinstance(early_response_awaitable, Awaitable):
            early_response = await early_response_awaitable
        else:
            # For backwards compatibility with 7.0.
            warnings.warn("declare process_request as a coroutine", DeprecationWarning)
            early_response = early_response_awaitable  # type: ignore

        # Change the response to a 503 error if the server is shutting down.
        if not self.ws_server.is_serving():
            early_response = (
                http.HTTPStatus.SERVICE_UNAVAILABLE,
                [],
                b"Server is shutting down.\n",
            )

        if early_response is not None:
            raise AbortHandshake(*early_response)

        key = check_request(request_headers)

        self.origin = self.process_origin(request_headers, origins)

        extensions_header, self.extensions = self.process_extensions(
            request_headers, available_extensions
        )

        protocol_header = self.subprotocol = self.process_subprotocol(
            request_headers, available_subprotocols
        )

        response_headers = Headers()

        build_response(response_headers, key)

        if extensions_header is not None:
            response_headers["Sec-WebSocket-Extensions"] = extensions_header

        if protocol_header is not None:
            response_headers["Sec-WebSocket-Protocol"] = protocol_header

        # extra_headers may be a callable computing headers from the request.
        if callable(extra_headers):
            extra_headers = extra_headers(path, self.request_headers)
        if extra_headers is not None:
            if isinstance(extra_headers, Headers):
                extra_headers = extra_headers.raw_items()
            elif isinstance(extra_headers, collections.abc.Mapping):
                extra_headers = extra_headers.items()
            for name, value in extra_headers:
                response_headers[name] = value

        response_headers.setdefault("Date", email.utils.formatdate(usegmt=True))
        response_headers.setdefault("Server", USER_AGENT)

        self.write_http_response(http.HTTPStatus.SWITCHING_PROTOCOLS, response_headers)

        self.connection_open()

        return path
589
590
class WebSocketServer:
    """
    WebSocket server returned by :func:`~websockets.server.serve`.

    This class provides the same interface as
    :class:`~asyncio.AbstractServer`, namely the
    :meth:`~asyncio.AbstractServer.close` and
    :meth:`~asyncio.AbstractServer.wait_closed` methods.

    It keeps track of WebSocket connections in order to close them properly
    when shutting down.

    Instances of this class store a reference to the :class:`~asyncio.Server`
    object returned by :meth:`~asyncio.loop.create_server` rather than inherit
    from :class:`~asyncio.Server` in part because
    :meth:`~asyncio.loop.create_server` doesn't support passing a custom
    :class:`~asyncio.Server` class.

    """

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        # Store a reference to loop to avoid relying on self.server._loop.
        self.loop = loop

        # Keep track of active connections.
        self.websockets: Set["WebSocketServerProtocol"] = set()

        # Task responsible for closing the server and terminating connections.
        self.close_task: Optional[asyncio.Task[None]] = None

        # Completed when the server is closed and connections are terminated.
        self.closed_waiter: asyncio.Future[None] = loop.create_future()

    def wrap(self, server: asyncio.AbstractServer) -> None:
        """
        Attach to a given :class:`~asyncio.Server`.

        Since :meth:`~asyncio.loop.create_server` doesn't support injecting a
        custom ``Server`` class, the easiest solution that doesn't rely on
        private :mod:`asyncio` APIs is to:

        - instantiate a :class:`WebSocketServer`
        - give the protocol factory a reference to that instance
        - call :meth:`~asyncio.loop.create_server` with the factory
        - attach the resulting :class:`~asyncio.Server` with this method

        """
        self.server = server

    # Annotations below are string forward references so that this class does
    # not require WebSocketServerProtocol to be evaluated at definition time.
    def register(self, protocol: "WebSocketServerProtocol") -> None:
        """
        Register a connection with this server.

        """
        self.websockets.add(protocol)

    def unregister(self, protocol: "WebSocketServerProtocol") -> None:
        """
        Unregister a connection with this server.

        :raises KeyError: if ``protocol`` isn't registered

        """
        self.websockets.remove(protocol)

    def is_serving(self) -> bool:
        """
        Tell whether the server is accepting new connections or shutting down.

        """
        try:
            # Python ≥ 3.7
            return self.server.is_serving()
        except AttributeError:  # pragma: no cover
            # Python < 3.7
            return self.server.sockets is not None

    def close(self) -> None:
        """
        Close the server.

        This method:

        * closes the underlying :class:`~asyncio.Server`;
        * rejects new WebSocket connections with an HTTP 503 (service
          unavailable) error; this happens when the server accepted the TCP
          connection but didn't complete the WebSocket opening handshake prior
          to closing;
        * closes open WebSocket connections with close code 1001 (going away).

        :meth:`close` is idempotent.

        """
        if self.close_task is None:
            self.close_task = self.loop.create_task(self._close())

    async def _close(self) -> None:
        """
        Implementation of :meth:`close`.

        This calls :meth:`~asyncio.Server.close` on the underlying
        :class:`~asyncio.Server` object to stop accepting new connections and
        then closes open connections with close code 1001.

        """
        # The ``loop`` argument of asyncio.sleep() and asyncio.wait() is
        # deprecated since Python 3.8 and was removed in Python 3.10, where
        # passing it at all (even as None) raises TypeError. Only pass it on
        # the versions that still need it.
        loop_kwargs: Any = {}
        if sys.version_info[:2] < (3, 8):
            loop_kwargs["loop"] = self.loop

        # Stop accepting new connections.
        self.server.close()

        # Wait until self.server.close() completes.
        await self.server.wait_closed()

        # Wait until all accepted connections reach connection_made() and call
        # register(). See https://bugs.python.org/issue34852 for details.
        await asyncio.sleep(0, **loop_kwargs)

        # Close OPEN connections with status code 1001. Since the server was
        # closed, handshake() closes OPENING connections with a HTTP 503 error.
        # Wait until all connections are closed.

        # asyncio.wait doesn't accept an empty first argument. The close()
        # coroutines are wrapped in tasks because asyncio.wait() deprecates
        # bare coroutines in Python 3.8 and rejects them in Python 3.11.
        if self.websockets:
            await asyncio.wait(
                [
                    self.loop.create_task(websocket.close(1001))
                    for websocket in self.websockets
                ],
                **loop_kwargs,
            )

        # Wait until all connection handlers are complete.

        # asyncio.wait doesn't accept an empty first argument.
        if self.websockets:
            await asyncio.wait(
                [websocket.handler_task for websocket in self.websockets],
                **loop_kwargs,
            )

        # Tell wait_closed() to return.
        self.closed_waiter.set_result(None)

    async def wait_closed(self) -> None:
        """
        Wait until the server is closed.

        When :meth:`wait_closed` returns, all TCP connections are closed and
        all connection handlers have returned.

        """
        await asyncio.shield(self.closed_waiter)

    @property
    def sockets(self) -> Optional[List[socket.socket]]:
        """
        List of :class:`~socket.socket` objects the server is listening to.

        ``None`` if the server is closed.

        """
        return self.server.sockets
748
749
class Serve:
    """
    Create, start, and return a WebSocket server on ``host`` and ``port``.

    Whenever a client connects, the server accepts the connection, creates a
    :class:`WebSocketServerProtocol`, performs the opening handshake, and
    delegates to the connection handler defined by ``ws_handler``. Once the
    handler completes, either normally or with an exception, the server
    performs the closing handshake and closes the connection.

    Awaiting :func:`serve` yields a :class:`WebSocketServer`. This instance
    provides :meth:`~websockets.server.WebSocketServer.close` and
    :meth:`~websockets.server.WebSocketServer.wait_closed` methods for
    terminating the server and cleaning up its resources.

    When a server is closed with :meth:`~WebSocketServer.close`, it closes all
    connections with close code 1001 (going away). Connection handlers, which
    are running the ``ws_handler`` coroutine, will receive a
    :exc:`~websockets.exceptions.ConnectionClosedOK` exception on their
    current or next interaction with the WebSocket connection.

    :func:`serve` can also be used as an asynchronous context manager. In
    this case, the server is shut down when exiting the context.

    :func:`serve` is a wrapper around the event loop's
    :meth:`~asyncio.loop.create_server` method. It creates and starts a
    :class:`~asyncio.Server` with :meth:`~asyncio.loop.create_server`. Then it
    wraps the :class:`~asyncio.Server` in a :class:`WebSocketServer` and
    returns the :class:`WebSocketServer`.

    The ``ws_handler`` argument is the WebSocket handler. It must be a
    coroutine accepting two arguments: a :class:`WebSocketServerProtocol` and
    the request URI.

    The ``host`` and ``port`` arguments, as well as unrecognized keyword
    arguments, are passed along to :meth:`~asyncio.loop.create_server`.

    For example, you can set the ``ssl`` keyword argument to a
    :class:`~ssl.SSLContext` to enable TLS.

    When ``path`` is set (this is what :func:`unix_serve` does), the server
    listens on that Unix socket with :meth:`~asyncio.loop.create_unix_server`
    instead; ``host`` and ``port`` must then be left unset, otherwise
    :exc:`ValueError` is raised.

    The ``create_protocol`` parameter allows customizing the
    :class:`~asyncio.Protocol` that manages the connection. It should be a
    callable or class accepting the same arguments as
    :class:`WebSocketServerProtocol` and returning an instance of
    :class:`WebSocketServerProtocol` or a subclass. It defaults to
    :class:`WebSocketServerProtocol`.

    The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
    ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is
    described in :class:`~websockets.protocol.WebSocketCommonProtocol`.

    ``timeout`` and ``klass`` are deprecated aliases of ``close_timeout``
    and ``create_protocol``; they emit a :exc:`DeprecationWarning` and are
    ignored when the new name is also given. ``close_timeout`` defaults to 10.

    :func:`serve` also accepts the following optional arguments:

    * ``compression`` is a shortcut to configure compression extensions;
      by default it enables the "permessage-deflate" extension; set it to
      ``None`` to disable compression
    * ``origins`` defines acceptable Origin HTTP headers; include ``None`` if
      the lack of an origin is acceptable
    * ``extensions`` is a list of supported extensions in order of
      decreasing preference
    * ``subprotocols`` is a list of supported subprotocols in order of
      decreasing preference
    * ``extra_headers`` sets additional HTTP response headers when the
      handshake succeeds; it can be a :class:`~websockets.http.Headers`
      instance, a :class:`~collections.abc.Mapping`, an iterable of ``(name,
      value)`` pairs, or a callable taking the request path and headers in
      arguments and returning one of the above
    * ``process_request`` allows intercepting the HTTP request; it must be a
      coroutine taking the request path and headers in argument; see
      :meth:`~WebSocketServerProtocol.process_request` for details
    * ``select_subprotocol`` allows customizing the logic for selecting a
      subprotocol; it must be a callable taking the subprotocols offered by
      the client and available on the server in argument; see
      :meth:`~WebSocketServerProtocol.select_subprotocol` for details

    Since there's no useful way to propagate exceptions triggered in handlers,
    they're sent to the ``'websockets.server'`` logger instead. Debugging is
    much easier if you configure logging to print them::

        import logging
        logger = logging.getLogger('websockets.server')
        logger.setLevel(logging.ERROR)
        logger.addHandler(logging.StreamHandler())

    """

    def __init__(
        self,
        ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]],
        host: Optional[Union[str, Sequence[str]]] = None,
        port: Optional[int] = None,
        *,
        path: Optional[str] = None,
        create_protocol: Optional[Type[WebSocketServerProtocol]] = None,
        ping_interval: float = 20,
        ping_timeout: float = 20,
        close_timeout: Optional[float] = None,
        max_size: int = 2 ** 20,
        max_queue: int = 2 ** 5,
        read_limit: int = 2 ** 16,
        write_limit: int = 2 ** 16,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        legacy_recv: bool = False,
        klass: Optional[Type[WebSocketServerProtocol]] = None,
        timeout: Optional[float] = None,
        compression: Optional[str] = "deflate",
        origins: Optional[Sequence[Optional[Origin]]] = None,
        extensions: Optional[Sequence[ServerExtensionFactory]] = None,
        subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLikeOrCallable] = None,
        process_request: Optional[
            Callable[[str, Headers], Awaitable[Optional[HTTPResponse]]]
        ] = None,
        select_subprotocol: Optional[
            Callable[[Sequence[Subprotocol], Sequence[Subprotocol]], Subprotocol]
        ] = None,
        **kwargs: Any,
    ) -> None:
        # Prepare the server lazily: nothing is bound until the instance is
        # awaited (or entered with ``async with``).

        # Backwards compatibility: close_timeout used to be called timeout.
        if timeout is None:
            timeout = 10
        else:
            # stacklevel=2 attributes the warning to the caller's line.
            warnings.warn(
                "rename timeout to close_timeout", DeprecationWarning, stacklevel=2
            )
        # If both are specified, timeout is ignored.
        if close_timeout is None:
            close_timeout = timeout

        # Backwards compatibility: create_protocol used to be called klass.
        if klass is None:
            klass = WebSocketServerProtocol
        else:
            warnings.warn(
                "rename klass to create_protocol", DeprecationWarning, stacklevel=2
            )
        # If both are specified, klass is ignored.
        if create_protocol is None:
            create_protocol = klass

        if loop is None:
            loop = asyncio.get_event_loop()

        ws_server = WebSocketServer(loop)

        # Passing an ssl argument through to the event loop means TLS.
        secure = kwargs.get("ssl") is not None

        if compression == "deflate":
            if extensions is None:
                extensions = []
            # Only add permessage-deflate if the caller didn't configure it.
            if not any(
                ext_factory.name == ServerPerMessageDeflateFactory.name
                for ext_factory in extensions
            ):
                extensions = list(extensions) + [ServerPerMessageDeflateFactory()]
        elif compression is not None:
            raise ValueError(f"unsupported compression: {compression}")

        # One protocol instance is built per incoming connection.
        factory = functools.partial(
            create_protocol,
            ws_handler,
            ws_server,
            host=host,
            port=port,
            secure=secure,
            ping_interval=ping_interval,
            ping_timeout=ping_timeout,
            close_timeout=close_timeout,
            max_size=max_size,
            max_queue=max_queue,
            read_limit=read_limit,
            write_limit=write_limit,
            loop=loop,
            legacy_recv=legacy_recv,
            origins=origins,
            extensions=extensions,
            subprotocols=subprotocols,
            extra_headers=extra_headers,
            process_request=process_request,
            select_subprotocol=select_subprotocol,
        )

        if path is None:
            create_server = functools.partial(
                loop.create_server, factory, host, port, **kwargs
            )
        else:
            # unix_serve(path) must not specify host and port parameters.
            # Raise explicitly instead of asserting: asserts are stripped
            # under ``python -O``, which would silently ignore host and port.
            if host is not None or port is not None:
                raise ValueError(
                    "host and port must not be specified when path is set"
                )
            create_server = functools.partial(
                loop.create_unix_server, factory, path, **kwargs
            )

        # This is a coroutine function.
        self._create_server = create_server
        self.ws_server = ws_server

    # async with serve(...)

    async def __aenter__(self) -> WebSocketServer:
        """Start the server and return the :class:`WebSocketServer`."""
        return await self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Close the server and wait until it has shut down."""
        self.ws_server.close()
        await self.ws_server.wait_closed()

    # await serve(...)

    def __await__(self) -> Generator[Any, None, WebSocketServer]:
        """Make the instance awaitable; the result is the WebSocketServer."""
        # Create a suitable iterator by calling __await__ on a coroutine.
        return self.__await_impl__().__await__()

    async def __await_impl__(self) -> WebSocketServer:
        """Create the underlying server, wrap it, and return the wrapper."""
        server = await self._create_server()
        self.ws_server.wrap(server)
        return self.ws_server

    # yield from serve(...)

    __iter__ = __await__
972
973
# Public alias: ``serve(...)`` constructs a ``Serve`` instance, which is then
# awaited (``await serve(...)``) or used as ``async with serve(...)``.
serve = Serve
975
976
def unix_serve(
    ws_handler: Callable[[WebSocketServerProtocol, str], Awaitable[Any]],
    path: str,
    **kwargs: Any,
) -> Serve:
    """
    Like :func:`serve`, but listen on a Unix socket rather than TCP.

    The server is created with the event loop's
    :meth:`~asyncio.loop.create_unix_server` method, so this is available
    on Unix only.

    This is handy when deploying behind a reverse proxy such as nginx.

    :param path: file system path to the Unix socket

    All remaining keyword arguments are forwarded unchanged to :func:`serve`.

    """
    # ``path`` is a named parameter, so it can never already be in kwargs.
    kwargs["path"] = path
    return serve(ws_handler, **kwargs)
996