1import asyncio
2import http
3import inspect
4import logging
5from typing import TYPE_CHECKING, Callable
6from urllib.parse import unquote
7
8import websockets
9from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
10
11from uvicorn.logging import TRACE_LOG_LEVEL
12from uvicorn.protocols.utils import get_local_addr, get_remote_addr, is_ssl
13
14
class Server:
    """
    Bare-bones server object handed to the protocol as its `ws_server`.

    It only tracks whether the server is shutting down; registering and
    unregistering connections are deliberate no-ops.
    """

    # Class-level default; assigning `closing = True` on an instance marks
    # that instance as no longer serving.
    closing = False

    def register(self, ws):
        """Accept a connection registration and do nothing with it."""

    def unregister(self, ws):
        """Accept a connection de-registration and do nothing with it."""

    def is_serving(self):
        """Return True for as long as `closing` is falsy."""
        return not self.closing
26
27
# Special-case the `logger` keyword argument, which websockets >=10 accepts.
# https://github.com/aaugustin/websockets/issues/1021#issuecomment-886222136
#
# The check inspects the signature of `WebSocketServerProtocol` at import time;
# under TYPE_CHECKING the first (empty) mixin is always the one selected.
if (
    TYPE_CHECKING
    or "logger" in inspect.signature(websockets.WebSocketServerProtocol).parameters
):

    class _LoggerMixin:
        # The base protocol takes `logger` itself, so there is nothing to adapt.
        pass


else:

    class _LoggerMixin:
        # The base protocol has no `logger` parameter: consume the keyword
        # here so it is not forwarded, then attach the logger ourselves.
        def __init__(self, *args, logger, **kwargs):
            super().__init__(*args, **kwargs)
            # Wrap the logger so every record carries this protocol instance
            # under the "websocket" key.
            self.logger = logging.LoggerAdapter(logger, {"websocket": self})
45
46
class WebSocketProtocol(_LoggerMixin, websockets.WebSocketServerProtocol):
    """
    WebSocket protocol that hands each connection over to an ASGI application.

    The application is started from `process_request`, and the opening
    handshake is held back until it replies with 'websocket.accept' or
    'websocket.close'. Afterwards `asgi_send` and `asgi_receive` carry the
    messages between the application and the connection.
    """

    def __init__(
        self, config, server_state, on_connection_lost: Callable = None, _loop=None
    ):
        """
        Set up per-connection state and initialise the underlying protocol.

        `config` is loaded first if it has not been loaded yet. `server_state`
        supplies the `connections` and `tasks` collections shared between
        connections. `on_connection_lost`, when given, is called without
        arguments from `connection_lost`. `_loop` overrides the event loop
        used to schedule the ASGI task.
        """
        if not config.loaded:
            config.load()

        self.config = config
        self.app = config.loaded_app
        self.on_connection_lost = on_connection_lost
        self.loop = _loop or asyncio.get_event_loop()
        self.root_path = config.root_path

        # Shared server state
        self.connections = server_state.connections
        self.tasks = server_state.tasks

        # Connection state
        self.transport = None
        self.server = None
        self.client = None
        self.scheme = None

        # Connection events
        self.scope = None
        # Set once the application has answered the handshake (accept/close),
        # or once a 500 response has been written instead.
        self.handshake_started_event = asyncio.Event()
        # Set when `ws_handler` is entered, or when the connection is lost.
        self.handshake_completed_event = asyncio.Event()
        # Set when the application closed, returned or raised.
        self.closed_event = asyncio.Event()
        self.initial_response = None
        self.connect_sent = False
        self.accepted_subprotocol = None
        self.transfer_data_task = None

        self.ws_server = Server()
        super().__init__(
            ws_handler=self.ws_handler,
            ws_server=self.ws_server,
            max_size=self.config.ws_max_size,
            ping_interval=self.config.ws_ping_interval,
            ping_timeout=self.config.ws_ping_timeout,
            extensions=[ServerPerMessageDeflateFactory()],
            logger=logging.getLogger("uvicorn.error"),
        )

    def connection_made(self, transport):
        """
        Record the new connection and its addresses, then delegate to the
        base protocol.
        """
        self.connections.add(self)
        self.transport = transport
        self.server = get_local_addr(transport)
        self.client = get_remote_addr(transport)
        self.scheme = "wss" if is_ssl(transport) else "ws"

        if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
            prefix = "%s:%d - " % tuple(self.client) if self.client else ""
            self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)

        super().connection_made(transport)

    def connection_lost(self, exc):
        """
        Drop the connection from the shared set, release anything waiting on
        the handshake, and notify the `on_connection_lost` callback if any.
        """
        self.connections.remove(self)

        if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
            prefix = "%s:%d - " % tuple(self.client) if self.client else ""
            self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)

        # Unblock coroutines that are still waiting for the handshake to
        # complete; it never will once the connection is gone.
        self.handshake_completed_event.set()
        super().connection_lost(exc)
        if self.on_connection_lost is not None:
            self.on_connection_lost()
        if exc is None:
            self.transport.close()

    def shutdown(self):
        """Mark the stub server as closing and close the transport."""
        self.ws_server.closing = True
        self.transport.close()

    def on_task_complete(self, task):
        """Done-callback that removes a finished ASGI task from the shared set."""
        self.tasks.discard(task)

    async def process_request(self, path, headers):
        """
        This hook is called to determine if the websocket should return
        an HTTP response and close.

        Our behavior here is to start the ASGI application, and then wait
        for either `accept` or `close` in order to determine if we should
        close the connection.

        Returns `self.initial_response`: None when the application accepted,
        or a `(status, headers, body)` tuple when it rejected the connection.
        """
        path_portion, _, query_string = path.partition("?")

        websockets.legacy.handshake.check_request(headers)

        # Each header value may itself hold a comma-separated list of tokens.
        subprotocols = []
        for header in headers.get_all("Sec-WebSocket-Protocol"):
            subprotocols.extend([token.strip() for token in header.split(",")])

        asgi_headers = [
            (name.encode("ascii"), value.encode("ascii"))
            for name, value in headers.raw_items()
        ]

        self.scope = {
            "type": "websocket",
            "asgi": {"version": self.config.asgi_version, "spec_version": "2.1"},
            "scheme": self.scheme,
            "server": self.server,
            "client": self.client,
            "root_path": self.root_path,
            "path": unquote(path_portion),
            # The ASGI spec defines `raw_path` as a byte string, like
            # `query_string` below.
            "raw_path": path_portion.encode("ascii"),
            "query_string": query_string.encode("ascii"),
            "headers": asgi_headers,
            "subprotocols": subprotocols,
        }
        task = self.loop.create_task(self.run_asgi())
        task.add_done_callback(self.on_task_complete)
        self.tasks.add(task)
        await self.handshake_started_event.wait()
        return self.initial_response

    def process_subprotocol(self, headers, available_subprotocols):
        """
        We override the standard 'process_subprotocol' behavior here so that
        we return whatever subprotocol is sent in the 'accept' message.
        """
        return self.accepted_subprotocol

    def send_500_response(self):
        """
        Write a plain-text '500 Internal Server Error' HTTP response straight
        to the transport and release the pending handshake.
        """
        msg = b"Internal Server Error"
        content = [
            b"HTTP/1.1 500 Internal Server Error\r\n",
            b"content-type: text/plain; charset=utf-8\r\n",
            b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
            b"connection: close\r\n",
            b"\r\n",
            msg,
        ]
        self.transport.write(b"".join(content))
        # Allow handler task to terminate cleanly, as websockets doesn't cancel it by
        # itself (see https://github.com/encode/uvicorn/issues/920)
        self.handshake_started_event.set()

    async def ws_handler(self, protocol, path):
        """
        This is the main handler function for the 'websockets' implementation
        to call into. We just wait for close then return, and instead allow
        'send' and 'receive' events to drive the flow.
        """
        self.handshake_completed_event.set()
        await self.closed_event.wait()

    async def run_asgi(self):
        """
        Wrapper around the ASGI callable, handling exceptions and unexpected
        termination states.
        """
        try:
            result = await self.app(self.scope, self.asgi_receive, self.asgi_send)
        except BaseException as exc:
            self.closed_event.set()
            msg = "Exception in ASGI application\n"
            self.logger.error(msg, exc_info=exc)
            if not self.handshake_started_event.is_set():
                # Nothing has been sent yet, so an HTTP error can still go out.
                self.send_500_response()
            else:
                # Let the handshake finish before closing the transport.
                await self.handshake_completed_event.wait()
            self.transport.close()
        else:
            self.closed_event.set()
            if not self.handshake_started_event.is_set():
                msg = "ASGI callable returned without sending handshake."
                self.logger.error(msg)
                self.send_500_response()
                self.transport.close()
            elif result is not None:
                msg = "ASGI callable should return None, but returned '%s'."
                self.logger.error(msg, result)
                await self.handshake_completed_event.wait()
                self.transport.close()

    async def asgi_send(self, message):
        """
        ASGI `send` callable.

        Before the handshake only 'websocket.accept' and 'websocket.close' are
        valid; afterwards only 'websocket.send' and 'websocket.close'. Any
        other message type, or any message sent after the application closed
        the connection, raises RuntimeError.
        """
        message_type = message["type"]

        if not self.handshake_started_event.is_set():
            if message_type == "websocket.accept":
                self.logger.info(
                    '%s - "WebSocket %s" [accepted]',
                    self.scope["client"],
                    self.scope["root_path"] + self.scope["path"],
                )
                self.initial_response = None
                self.accepted_subprotocol = message.get("subprotocol")
                self.handshake_started_event.set()

            elif message_type == "websocket.close":
                self.logger.info(
                    '%s - "WebSocket %s" 403',
                    self.scope["client"],
                    self.scope["root_path"] + self.scope["path"],
                )
                # Closing before accepting rejects the handshake with a 403.
                self.initial_response = (http.HTTPStatus.FORBIDDEN, [], b"")
                self.handshake_started_event.set()
                self.closed_event.set()

            else:
                msg = (
                    "Expected ASGI message 'websocket.accept' or 'websocket.close', "
                    "but got '%s'."
                )
                raise RuntimeError(msg % message_type)

        elif not self.closed_event.is_set():
            await self.handshake_completed_event.wait()

            if message_type == "websocket.send":
                bytes_data = message.get("bytes")
                text_data = message.get("text")
                # Binary payload wins whenever it is present.
                data = text_data if bytes_data is None else bytes_data
                await self.send(data)

            elif message_type == "websocket.close":
                code = message.get("code", 1000)
                # A reason that is missing or explicitly None both mean the
                # empty string.
                reason = message.get("reason", "") or ""
                await self.close(code, reason)
                self.closed_event.set()

            else:
                msg = (
                    "Expected ASGI message 'websocket.send' or 'websocket.close',"
                    " but got '%s'."
                )
                raise RuntimeError(msg % message_type)

        else:
            msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
            raise RuntimeError(msg % message_type)

    async def asgi_receive(self):
        """
        ASGI `receive` callable.

        The first call returns 'websocket.connect'. Later calls wait for the
        handshake to complete and return the next frame as a
        'websocket.receive' message carrying either 'text' or 'bytes', or a
        'websocket.disconnect' message once the connection is closed.
        """
        if not self.connect_sent:
            self.connect_sent = True
            return {"type": "websocket.connect"}

        await self.handshake_completed_event.wait()
        try:
            await self.ensure_open()
            data = await self.recv()
        except websockets.ConnectionClosed as exc:
            return {"type": "websocket.disconnect", "code": exc.code}

        msg = {"type": "websocket.receive"}

        if isinstance(data, str):
            msg["text"] = data
        else:
            msg["bytes"] = data

        return msg
303