import asyncio
import http
import inspect
import logging
from typing import TYPE_CHECKING, Callable, Optional
from urllib.parse import unquote

import websockets
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory

from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.utils import get_local_addr, get_remote_addr, is_ssl


class Server:
    """Minimal object passed to websockets as ``ws_server``.

    ``register``/``unregister`` are no-ops. ``is_serving()`` reports ``False``
    once ``closing`` has been set, which ``WebSocketProtocol.shutdown`` does.
    """

    closing = False

    def register(self, ws):
        pass

    def unregister(self, ws):
        pass

    def is_serving(self):
        return not self.closing


# special case logger kwarg in websockets >=10
# https://github.com/aaugustin/websockets/issues/1021#issuecomment-886222136
#
# If the installed WebSocketServerProtocol already accepts ``logger``, the mixin
# is empty and the kwarg is forwarded as-is. Otherwise the mixin consumes the
# ``logger`` kwarg itself (so the base class never sees it) and exposes
# ``self.logger`` as a LoggerAdapter.
if (
    TYPE_CHECKING
    or "logger" in inspect.signature(websockets.WebSocketServerProtocol).parameters
):

    class _LoggerMixin:
        pass


else:

    class _LoggerMixin:
        def __init__(self, *args, logger, **kwargs):
            super().__init__(*args, **kwargs)
            self.logger = logging.LoggerAdapter(logger, {"websocket": self})


class WebSocketProtocol(_LoggerMixin, websockets.WebSocketServerProtocol):
    """Bridge between the websockets server protocol and an ASGI application.

    The ASGI app is started from ``process_request`` during the opening
    handshake. Its first message (``websocket.accept`` or ``websocket.close``)
    decides whether the handshake proceeds or is rejected with a 403. After
    that, ``asgi_send``/``asgi_receive`` map ASGI messages onto websockets'
    ``send``/``close``/``recv``.
    """

    def __init__(
        self,
        config,
        server_state,
        on_connection_lost: Optional[Callable] = None,
        _loop=None,
    ):
        if not config.loaded:
            config.load()

        self.config = config
        self.app = config.loaded_app
        self.on_connection_lost = on_connection_lost
        self.loop = _loop or asyncio.get_event_loop()
        self.root_path = config.root_path

        # Shared server state
        self.connections = server_state.connections
        self.tasks = server_state.tasks

        # Connection state
        self.transport = None
        self.server = None
        self.client = None
        self.scheme = None

        # Connection events
        self.scope = None
        # Set once the app has accepted or rejected the handshake (or a 500
        # was sent); process_request waits on it before answering websockets.
        self.handshake_started_event = asyncio.Event()
        # Set when websockets calls ws_handler (handshake done) or when the
        # connection is lost, so waiters never block forever.
        self.handshake_completed_event = asyncio.Event()
        # Set when the app closes the socket or the app task finishes;
        # ws_handler returns once it is set.
        self.closed_event = asyncio.Event()
        self.initial_response = None
        self.connect_sent = False
        self.accepted_subprotocol = None
        self.transfer_data_task = None

        self.ws_server = Server()

        super().__init__(
            ws_handler=self.ws_handler,
            ws_server=self.ws_server,
            max_size=self.config.ws_max_size,
            ping_interval=self.config.ws_ping_interval,
            ping_timeout=self.config.ws_ping_timeout,
            extensions=[ServerPerMessageDeflateFactory()],
            logger=logging.getLogger("uvicorn.error"),
        )

    def connection_made(self, transport):
        """Register the connection and record addresses before websockets takes over."""
        self.connections.add(self)
        self.transport = transport
        self.server = get_local_addr(transport)
        self.client = get_remote_addr(transport)
        self.scheme = "wss" if is_ssl(transport) else "ws"

        if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
            prefix = "%s:%d - " % tuple(self.client) if self.client else ""
            self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)

        super().connection_made(transport)

    def connection_lost(self, exc):
        """Unregister the connection, release handshake waiters and notify the owner.

        :param exc: the exception that closed the connection, or ``None`` for
            a clean close (in which case the transport is closed explicitly).
        """
        self.connections.remove(self)

        if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
            prefix = "%s:%d - " % tuple(self.client) if self.client else ""
            self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)

        # Unblock asgi_send / asgi_receive / run_asgi, which all wait on this.
        self.handshake_completed_event.set()
        super().connection_lost(exc)
        if self.on_connection_lost is not None:
            self.on_connection_lost()
        if exc is None:
            self.transport.close()

    def shutdown(self):
        """Mark the server as closing (``is_serving()`` -> False) and close the transport."""
        self.ws_server.closing = True
        self.transport.close()

    def on_task_complete(self, task):
        """Done-callback: drop the finished ASGI task from the shared task set."""
        self.tasks.discard(task)

    async def process_request(self, path, headers):
        """
        This hook is called to determine if the websocket should return
        an HTTP response and close.

        Our behavior here is to start the ASGI application, and then wait
        for either `accept` or `close` in order to determine if we should
        close the connection.

        :param path: request target (path plus optional ``?query``) as a str.
        :param headers: websockets ``Headers`` of the upgrade request.
        :returns: ``None`` to continue the handshake, or a
            ``(status, headers, body)`` tuple to reject it.
        """
        path_portion, _, query_string = path.partition("?")

        websockets.legacy.handshake.check_request(headers)

        subprotocols = []
        for header in headers.get_all("Sec-WebSocket-Protocol"):
            subprotocols.extend([token.strip() for token in header.split(",")])

        # websockets' legacy HTTP parser decodes the request target and header
        # values as ASCII with "surrogateescape". Encoding back the same way
        # restores the exact raw bytes instead of raising UnicodeEncodeError on
        # non-ASCII input; for pure-ASCII input the result is unchanged.
        asgi_headers = [
            (name.encode("ascii"), value.encode("ascii", "surrogateescape"))
            for name, value in headers.raw_items()
        ]

        self.scope = {
            "type": "websocket",
            "asgi": {"version": self.config.asgi_version, "spec_version": "2.1"},
            "scheme": self.scheme,
            "server": self.server,
            "client": self.client,
            "root_path": self.root_path,
            "path": unquote(path_portion),
            # The ASGI spec defines raw_path as a byte string.
            "raw_path": path_portion.encode("ascii", "surrogateescape"),
            "query_string": query_string.encode("ascii", "surrogateescape"),
            "headers": asgi_headers,
            "subprotocols": subprotocols,
        }
        task = self.loop.create_task(self.run_asgi())
        task.add_done_callback(self.on_task_complete)
        self.tasks.add(task)
        # Block the handshake until the app sends accept/close (or fails).
        await self.handshake_started_event.wait()
        return self.initial_response

    def process_subprotocol(self, headers, available_subprotocols):
        """
        We override the standard 'process_subprotocol' behavior here so that
        we return whatever subprotocol is sent in the 'accept' message.
        """
        return self.accepted_subprotocol

    def send_500_response(self):
        """Write a raw HTTP 500 response to the transport and release the handshake."""
        msg = b"Internal Server Error"
        content = [
            b"HTTP/1.1 500 Internal Server Error\r\n"
            b"content-type: text/plain; charset=utf-8\r\n",
            b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
            b"connection: close\r\n",
            b"\r\n",
            msg,
        ]
        self.transport.write(b"".join(content))
        # Allow handler task to terminate cleanly, as websockets doesn't cancel it by
        # itself (see https://github.com/encode/uvicorn/issues/920)
        self.handshake_started_event.set()

    async def ws_handler(self, protocol, path):
        """
        This is the main handler function for the 'websockets' implementation
        to call into. We just wait for close then return, and instead allow
        'send' and 'receive' events to drive the flow.
        """
        self.handshake_completed_event.set()
        await self.closed_event.wait()

    async def run_asgi(self):
        """
        Wrapper around the ASGI callable, handling exceptions and unexpected
        termination states.
        """
        try:
            result = await self.app(self.scope, self.asgi_receive, self.asgi_send)
        except BaseException as exc:
            self.closed_event.set()
            msg = "Exception in ASGI application\n"
            self.logger.error(msg, exc_info=exc)
            if not self.handshake_started_event.is_set():
                # Handshake never decided: answer it with a plain 500.
                self.send_500_response()
            else:
                await self.handshake_completed_event.wait()
            self.transport.close()
        else:
            self.closed_event.set()
            if not self.handshake_started_event.is_set():
                msg = "ASGI callable returned without sending handshake."
                self.logger.error(msg)
                self.send_500_response()
                self.transport.close()
            elif result is not None:
                msg = "ASGI callable should return None, but returned '%s'."
                self.logger.error(msg, result)
                await self.handshake_completed_event.wait()
                self.transport.close()

    async def asgi_send(self, message):
        """ASGI ``send`` callable.

        Before the handshake: ``websocket.accept`` completes it and
        ``websocket.close`` rejects it with 403. After it: ``websocket.send``
        and ``websocket.close`` map onto websockets' ``send``/``close``.

        :raises RuntimeError: on a message type not valid in the current state.
        """
        message_type = message["type"]

        if not self.handshake_started_event.is_set():
            if message_type == "websocket.accept":
                self.logger.info(
                    '%s - "WebSocket %s" [accepted]',
                    self.scope["client"],
                    self.scope["root_path"] + self.scope["path"],
                )
                self.initial_response = None
                self.accepted_subprotocol = message.get("subprotocol")
                self.handshake_started_event.set()

            elif message_type == "websocket.close":
                self.logger.info(
                    '%s - "WebSocket %s" 403',
                    self.scope["client"],
                    self.scope["root_path"] + self.scope["path"],
                )
                self.initial_response = (http.HTTPStatus.FORBIDDEN, [], b"")
                self.handshake_started_event.set()
                self.closed_event.set()

            else:
                msg = (
                    "Expected ASGI message 'websocket.accept' or 'websocket.close', "
                    "but got '%s'."
                )
                raise RuntimeError(msg % message_type)

        elif not self.closed_event.is_set():
            # Frames can only be sent once websockets has finished the handshake.
            await self.handshake_completed_event.wait()

            if message_type == "websocket.send":
                bytes_data = message.get("bytes")
                text_data = message.get("text")
                data = text_data if bytes_data is None else bytes_data
                await self.send(data)

            elif message_type == "websocket.close":
                code = message.get("code", 1000)
                reason = message.get("reason", "")
                await self.close(code, reason)
                self.closed_event.set()

            else:
                msg = (
                    "Expected ASGI message 'websocket.send' or 'websocket.close',"
                    " but got '%s'."
                )
                raise RuntimeError(msg % message_type)

        else:
            msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
            raise RuntimeError(msg % message_type)

    async def asgi_receive(self):
        """ASGI ``receive`` callable.

        The first call returns ``websocket.connect``. Later calls wait for the
        handshake, then return ``websocket.receive`` (``text`` for str frames,
        ``bytes`` otherwise) or ``websocket.disconnect`` with the close code.
        """
        if not self.connect_sent:
            self.connect_sent = True
            return {"type": "websocket.connect"}

        await self.handshake_completed_event.wait()
        try:
            await self.ensure_open()
            data = await self.recv()
        except websockets.ConnectionClosed as exc:
            return {"type": "websocket.disconnect", "code": exc.code}

        msg = {"type": "websocket.receive"}

        if isinstance(data, str):
            msg["text"] = data
        else:
            msg["bytes"] = data

        return msg