"""
wsproto/handshake
~~~~~~~~~~~~~~~~~~

An implementation of WebSocket handshakes.
"""
from collections import deque
from typing import (
    cast,
    Deque,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Sequence,
    Union,
)

import h11

from .connection import Connection, ConnectionState, ConnectionType
from .events import AcceptConnection, Event, RejectConnection, RejectData, Request
from .extensions import Extension
from .typing import Headers
from .utilities import (
    generate_accept_token,
    generate_nonce,
    LocalProtocolError,
    normed_header_dict,
    RemoteProtocolError,
    split_comma_header,
)

# RFC6455, Section 4.2.1/6 - Reading the Client's Opening Handshake
WEBSOCKET_VERSION = b"13"


class H11Handshake:
    """A Handshake implementation for HTTP/1.1 connections."""

    def __init__(self, connection_type: ConnectionType) -> None:
        self.client = connection_type is ConnectionType.CLIENT
        self._state = ConnectionState.CONNECTING

        if self.client:
            self._h11_connection = h11.Connection(h11.CLIENT)
        else:
            self._h11_connection = h11.Connection(h11.SERVER)

        # Set once the handshake succeeds (see _accept and
        # _establish_client_connection); None until then.
        self._connection: Optional[Connection] = None
        # Events produced by receive_data, drained by events().
        self._events: Deque[Event] = deque()
        # Server: the parsed incoming request. Client: the request we sent.
        self._initiating_request: Optional[Request] = None
        # Client only: the Sec-WebSocket-Key we sent, used to verify the
        # server's Sec-WebSocket-Accept token.
        self._nonce: Optional[bytes] = None

    @property
    def state(self) -> ConnectionState:
        """The current state of the handshake."""
        return self._state

    @property
    def connection(self) -> Optional[Connection]:
        """Return the established connection.

        This returns the :class:`Connection` created once the handshake
        has completed successfully, or ``None`` if the connection has not
        yet been established.

        :rtype: Optional[Connection]
        """
        return self._connection

    def initiate_upgrade_connection(self, headers: Headers, path: str) -> None:
        """Initiate an upgrade connection.

        This should be used if the request has already been received and
        parsed. The request is re-serialised with a throwaway h11 client
        and fed through :meth:`receive_data` as if it had arrived on the
        wire.

        :param list headers: HTTP headers represented as a list of 2-tuples.
        :param str path: A URL path.
        :raises LocalProtocolError: If this handshake is acting as the client.
        """
        if self.client:
            raise LocalProtocolError(
                "Cannot initiate an upgrade connection when acting as the client"
            )
        upgrade_request = h11.Request(method=b"GET", target=path, headers=headers)
        h11_client = h11.Connection(h11.CLIENT)
        self.receive_data(h11_client.send(upgrade_request))

    def send(self, event: Event) -> bytes:
        """Send an event to the remote.

        This will return the bytes to send based on the event or raise
        a LocalProtocolError if the event is not valid given the
        state.

        :returns: Data to send to the WebSocket peer.
        :rtype: bytes
        :raises LocalProtocolError: If ``event`` cannot be sent during the
            handshake, or is not valid in the current state.
        """
        data = b""
        if isinstance(event, Request):
            data += self._initiate_connection(event)
        elif isinstance(event, AcceptConnection):
            data += self._accept(event)
        elif isinstance(event, RejectConnection):
            data += self._reject(event)
        elif isinstance(event, RejectData):
            data += self._send_reject_data(event)
        else:
            raise LocalProtocolError(
                f"Event {event} cannot be sent during the handshake"
            )
        return data

    def receive_data(self, data: Optional[bytes]) -> None:
        """Receive data from the remote.

        A list of events that the remote peer triggered by sending
        this data can be retrieved with :meth:`events`.

        :param bytes data: Data received from the WebSocket peer.
        :raises RemoteProtocolError: If the peer sent a malformed HTTP
            message or an invalid handshake.
        """
        self._h11_connection.receive_data(data)
        while True:
            try:
                event = self._h11_connection.next_event()
            except h11.RemoteProtocolError as exc:
                # Chain the h11 error so the underlying parse failure is
                # still visible to whoever handles this exception.
                raise RemoteProtocolError(
                    "Bad HTTP message", event_hint=RejectConnection()
                ) from exc
            if (
                isinstance(event, h11.ConnectionClosed)
                or event is h11.NEED_DATA
                or event is h11.PAUSED
            ):
                break

            if self.client:
                if isinstance(event, h11.InformationalResponse):
                    if event.status_code == 101:
                        self._events.append(self._establish_client_connection(event))
                    else:
                        # Any other 1xx is treated as a body-less rejection.
                        self._events.append(
                            RejectConnection(
                                headers=event.headers,
                                status_code=event.status_code,
                                has_body=False,
                            )
                        )
                        self._state = ConnectionState.CLOSED
                elif isinstance(event, h11.Response):
                    # A final (non-1xx) response is a rejection; its body
                    # follows as RejectData events.
                    self._state = ConnectionState.REJECTING
                    self._events.append(
                        RejectConnection(
                            headers=event.headers,
                            status_code=event.status_code,
                            has_body=True,
                        )
                    )
                elif isinstance(event, h11.Data):
                    self._events.append(
                        RejectData(data=event.data, body_finished=False)
                    )
                elif isinstance(event, h11.EndOfMessage):
                    self._events.append(RejectData(data=b"", body_finished=True))
                    self._state = ConnectionState.CLOSED
            else:
                if isinstance(event, h11.Request):
                    self._events.append(self._process_connection_request(event))

    def events(self) -> Generator[Event, None, None]:
        """Return a generator that provides any events that have been generated
        by protocol activity.

        Events are removed from the internal queue as they are yielded.

        :returns: a generator that yields wsproto events.
        """
        while self._events:
            yield self._events.popleft()

    ############ Server mode methods

    def _process_connection_request(  # noqa: MC0001
        self, event: h11.Request
    ) -> Request:
        """Validate an incoming upgrade request and convert it to a Request.

        :raises RemoteProtocolError: If the request is not a valid
            WebSocket opening handshake.
        """
        if event.method != b"GET":
            raise RemoteProtocolError(
                "Request method must be GET", event_hint=RejectConnection()
            )
        connection_tokens = None
        extensions: List[str] = []
        host = None
        key = None
        subprotocols: List[str] = []
        upgrade = b""
        version = None
        headers: Headers = []
        for name, value in event.headers:
            name = name.lower()
            if name == b"connection":
                connection_tokens = split_comma_header(value)
            elif name == b"host":
                host = value.decode("ascii")
                continue  # Skip appending to headers
            elif name == b"sec-websocket-extensions":
                extensions = split_comma_header(value)
                continue  # Skip appending to headers
            elif name == b"sec-websocket-key":
                # Deliberately kept in headers: _accept reads it back from
                # the request's extra_headers.
                key = value
            elif name == b"sec-websocket-protocol":
                subprotocols = split_comma_header(value)
                continue  # Skip appending to headers
            elif name == b"sec-websocket-version":
                version = value
            elif name == b"upgrade":
                upgrade = value
            headers.append((name, value))
        if connection_tokens is None or not any(
            token.lower() == "upgrade" for token in connection_tokens
        ):
            raise RemoteProtocolError(
                "Missing header, 'Connection: Upgrade'", event_hint=RejectConnection()
            )
        if version != WEBSOCKET_VERSION:
            # 426 when a (wrong) version was sent, 400 when it was absent;
            # either way advertise the version we support.
            raise RemoteProtocolError(
                "Missing header, 'Sec-WebSocket-Version'",
                event_hint=RejectConnection(
                    headers=[(b"Sec-WebSocket-Version", WEBSOCKET_VERSION)],
                    status_code=426 if version else 400,
                ),
            )
        if key is None:
            raise RemoteProtocolError(
                "Missing header, 'Sec-WebSocket-Key'", event_hint=RejectConnection()
            )
        if upgrade.lower() != b"websocket":
            raise RemoteProtocolError(
                "Missing header, 'Upgrade: WebSocket'", event_hint=RejectConnection()
            )
        if host is None:
            raise RemoteProtocolError(
                "Missing header, 'Host'", event_hint=RejectConnection()
            )

        self._initiating_request = Request(
            extensions=extensions,
            extra_headers=headers,
            host=host,
            subprotocols=subprotocols,
            target=event.target.decode("ascii"),
        )
        return self._initiating_request

    def _accept(self, event: AcceptConnection) -> bytes:
        """Build the 101 Switching Protocols response and open the connection.

        :raises LocalProtocolError: If the chosen subprotocol was not
            offered by the client.
        """
        # _accept is always called after _process_connection_request.
        assert self._initiating_request is not None
        request_headers = normed_header_dict(self._initiating_request.extra_headers)

        nonce = request_headers[b"sec-websocket-key"]
        accept_token = generate_accept_token(nonce)

        headers = [
            (b"Upgrade", b"WebSocket"),
            (b"Connection", b"Upgrade"),
            (b"Sec-WebSocket-Accept", accept_token),
        ]

        if event.subprotocol is not None:
            if event.subprotocol not in self._initiating_request.subprotocols:
                raise LocalProtocolError(f"unexpected subprotocol {event.subprotocol}")
            headers.append(
                (b"Sec-WebSocket-Protocol", event.subprotocol.encode("ascii"))
            )

        if event.extensions:
            accepts = server_extensions_handshake(
                cast(Sequence[str], self._initiating_request.extensions),
                event.extensions,
            )
            if accepts:
                headers.append((b"Sec-WebSocket-Extensions", accepts))

        response = h11.InformationalResponse(
            status_code=101, headers=headers + event.extra_headers
        )
        self._connection = Connection(
            ConnectionType.CLIENT if self.client else ConnectionType.SERVER,
            event.extensions,
        )
        self._state = ConnectionState.OPEN
        return self._h11_connection.send(response)

    def _reject(self, event: RejectConnection) -> bytes:
        """Serialise a rejection response.

        :raises LocalProtocolError: If the handshake is no longer CONNECTING.
        """
        if self.state != ConnectionState.CONNECTING:
            raise LocalProtocolError(
                "Connection cannot be rejected in state %s" % self.state
            )

        # Copy so the caller's event is not mutated by the content-length
        # header appended below (re-sending the event would duplicate it).
        headers = list(event.headers)
        if not event.has_body:
            headers.append((b"content-length", b"0"))
        response = h11.Response(status_code=event.status_code, headers=headers)
        data = self._h11_connection.send(response)
        self._state = ConnectionState.REJECTING
        if not event.has_body:
            data += self._h11_connection.send(h11.EndOfMessage())
            self._state = ConnectionState.CLOSED
        return data

    def _send_reject_data(self, event: RejectData) -> bytes:
        """Serialise a chunk of rejection body, finishing it if requested.

        :raises LocalProtocolError: If the handshake is not REJECTING.
        """
        if self.state != ConnectionState.REJECTING:
            raise LocalProtocolError(
                f"Cannot send rejection data in state {self.state}"
            )

        data = self._h11_connection.send(h11.Data(data=event.data))
        if event.body_finished:
            data += self._h11_connection.send(h11.EndOfMessage())
            self._state = ConnectionState.CLOSED
        return data

    ############ Client mode methods

    def _initiate_connection(self, request: Request) -> bytes:
        """Serialise the client's opening handshake request."""
        self._initiating_request = request
        self._nonce = generate_nonce()

        headers = [
            (b"Host", request.host.encode("ascii")),
            (b"Upgrade", b"WebSocket"),
            (b"Connection", b"Upgrade"),
            (b"Sec-WebSocket-Key", self._nonce),
            (b"Sec-WebSocket-Version", WEBSOCKET_VERSION),
        ]

        if request.subprotocols:
            headers.append(
                (
                    b"Sec-WebSocket-Protocol",
                    (", ".join(request.subprotocols)).encode("ascii"),
                )
            )

        if request.extensions:
            offers: Dict[str, Union[str, bool]] = {}
            for e in request.extensions:
                assert isinstance(e, Extension)
                offers[e.name] = e.offer()
            extensions = []
            for name, params in offers.items():
                bname = name.encode("ascii")
                if isinstance(params, bool):
                    # True means "offer with no parameters"; False means
                    # "do not offer".
                    if params:
                        extensions.append(bname)
                else:
                    extensions.append(b"%s; %s" % (bname, params.encode("ascii")))
            if extensions:
                headers.append((b"Sec-WebSocket-Extensions", b", ".join(extensions)))

        upgrade = h11.Request(
            method=b"GET",
            target=request.target.encode("ascii"),
            headers=headers + request.extra_headers,
        )
        return self._h11_connection.send(upgrade)

    def _establish_client_connection(
        self, event: h11.InformationalResponse
    ) -> AcceptConnection:  # noqa: MC0001
        """Validate the server's 101 response and open the connection.

        :raises RemoteProtocolError: If the response is not a valid
            acceptance of the handshake we sent.
        """
        # _establish_client_connection is always called after _initiate_connection.
        assert self._initiating_request is not None
        assert self._nonce is not None

        accept = None
        connection_tokens = None
        accepts: List[str] = []
        subprotocol = None
        upgrade = b""
        headers: Headers = []
        for name, value in event.headers:
            name = name.lower()
            if name == b"connection":
                connection_tokens = split_comma_header(value)
                continue  # Skip appending to headers
            elif name == b"sec-websocket-extensions":
                accepts = split_comma_header(value)
                continue  # Skip appending to headers
            elif name == b"sec-websocket-accept":
                accept = value
                continue  # Skip appending to headers
            elif name == b"sec-websocket-protocol":
                subprotocol = value
                continue  # Skip appending to headers
            elif name == b"upgrade":
                upgrade = value
                continue  # Skip appending to headers
            headers.append((name, value))

        if connection_tokens is None or not any(
            token.lower() == "upgrade" for token in connection_tokens
        ):
            raise RemoteProtocolError(
                "Missing header, 'Connection: Upgrade'", event_hint=RejectConnection()
            )
        if upgrade.lower() != b"websocket":
            raise RemoteProtocolError(
                "Missing header, 'Upgrade: WebSocket'", event_hint=RejectConnection()
            )
        accept_token = generate_accept_token(self._nonce)
        if accept != accept_token:
            raise RemoteProtocolError("Bad accept token", event_hint=RejectConnection())
        if subprotocol is not None:
            subprotocol = subprotocol.decode("ascii")
            if subprotocol not in self._initiating_request.subprotocols:
                raise RemoteProtocolError(
                    f"unrecognized subprotocol {subprotocol}",
                    event_hint=RejectConnection(),
                )
        extensions = client_extensions_handshake(
            accepts, cast(Sequence[Extension], self._initiating_request.extensions)
        )

        # Any bytes h11 received after the 101 response belong to the
        # WebSocket stream, so hand them to the new Connection.
        self._connection = Connection(
            ConnectionType.CLIENT if self.client else ConnectionType.SERVER,
            extensions,
            self._h11_connection.trailing_data[0],
        )
        self._state = ConnectionState.OPEN
        return AcceptConnection(
            extensions=extensions, extra_headers=headers, subprotocol=subprotocol
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(client={self.client}, state={self.state})"


def server_extensions_handshake(
    requested: Iterable[str], supported: List[Extension]
) -> Optional[bytes]:
    """Agree on the extensions to use returning an appropriate header value.

    :param requested: Extension offers from the client's
        Sec-WebSocket-Extensions header, one string per offer.
    :param supported: Extensions the server is willing to use.
    :returns: The Sec-WebSocket-Extensions response header value, or None
        if there are no agreed extensions.
    """
    accepts: Dict[str, Union[bool, bytes]] = {}
    for offer in requested:
        name = offer.split(";", 1)[0].strip()
        for extension in supported:
            if extension.name == name:
                accept = extension.accept(offer)
                # True accepts with no parameters; a string accepts with
                # those parameters; False/None declines the offer.
                if isinstance(accept, bool):
                    if accept:
                        accepts[extension.name] = True
                elif accept is not None:
                    accepts[extension.name] = accept.encode("ascii")

    if accepts:
        extensions: List[bytes] = []
        for name, params in accepts.items():
            name_bytes = name.encode("ascii")
            if isinstance(params, bool):
                assert params
                extensions.append(name_bytes)
            elif params == b"":
                extensions.append(name_bytes)
            else:
                extensions.append(b"%s; %s" % (name_bytes, params))
        return b", ".join(extensions)

    return None


def client_extensions_handshake(
    accepted: Iterable[str], supported: Sequence[Extension]
) -> List[Extension]:
    """Finalize the extensions the server accepted.

    :param accepted: Entries from the server's Sec-WebSocket-Extensions
        response header.
    :param supported: The extensions the client offered.
    :returns: The offered extensions that were accepted, finalized, in the
        order the server listed them.
    :raises RemoteProtocolError: If the server accepted an extension the
        client did not offer.
    """
    extensions = []
    for accept in accepted:
        name = accept.split(";", 1)[0].strip()
        for extension in supported:
            if extension.name == name:
                extension.finalize(accept)
                extensions.append(extension)
                break
        else:
            raise RemoteProtocolError(
                f"unrecognized extension {name}", event_hint=RejectConnection()
            )
    return extensions