"""
:mod:`websockets.client` defines the WebSocket client APIs.

"""

import asyncio
import collections.abc
import functools
import logging
import urllib.parse
import warnings
from types import TracebackType
from typing import Any, Generator, List, Optional, Sequence, Tuple, Type, cast

from .exceptions import (
    InvalidHandshake,
    InvalidHeader,
    InvalidMessage,
    InvalidStatusCode,
    NegotiationError,
    RedirectHandshake,
    SecurityError,
)
from .extensions.base import ClientExtensionFactory, Extension
from .extensions.permessage_deflate import ClientPerMessageDeflateFactory
from .handshake import build_request, check_response
from .headers import (
    build_authorization_basic,
    build_extension,
    build_subprotocol,
    parse_extension,
    parse_subprotocol,
)
from .http import USER_AGENT, Headers, HeadersLike, read_response
from .protocol import WebSocketCommonProtocol
from .typing import ExtensionHeader, Origin, Subprotocol
from .uri import WebSocketURI, parse_uri


__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"]

logger = logging.getLogger(__name__)


class WebSocketClientProtocol(WebSocketCommonProtocol):
    """
    :class:`~asyncio.Protocol` subclass implementing a WebSocket client.

    This class inherits most of its methods from
    :class:`~websockets.protocol.WebSocketCommonProtocol`.

    """

    is_client = True
    side = "client"

    def __init__(
        self,
        *,
        origin: Optional[Origin] = None,
        extensions: Optional[Sequence[ClientExtensionFactory]] = None,
        subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLike] = None,
        **kwargs: Any,
    ) -> None:
        # These settings are only stored here; Connect reads them back and
        # passes them to handshake() once the connection is established.
        self.origin = origin
        self.available_extensions = extensions
        self.available_subprotocols = subprotocols
        self.extra_headers = extra_headers
        super().__init__(**kwargs)

    def write_http_request(self, path: str, headers: Headers) -> None:
        """
        Write request line and headers to the HTTP request.

        The path and headers are also saved in ``self.path`` and
        ``self.request_headers``.

        """
        self.path = path
        self.request_headers = headers

        logger.debug("%s > GET %s HTTP/1.1", self.side, path)
        logger.debug("%s > %r", self.side, headers)

        # Since the path and headers only contain ASCII characters,
        # we can keep this simple.
        request = f"GET {path} HTTP/1.1\r\n"
        request += str(headers)

        self.transport.write(request.encode())

    async def read_http_response(self) -> Tuple[int, Headers]:
        """
        Read status line and headers from the HTTP response.

        If the response contains a body, it may be read from ``self.reader``
        after this coroutine returns.

        The headers are also saved in ``self.response_headers``.

        :raises ~websockets.exceptions.InvalidMessage: if the HTTP message is
            malformed or isn't a valid HTTP/1.1 response

        """
        try:
            status_code, reason, headers = await read_response(self.reader)
        except Exception as exc:
            raise InvalidMessage("did not receive a valid HTTP response") from exc

        logger.debug("%s < HTTP/1.1 %d %s", self.side, status_code, reason)
        logger.debug("%s < %r", self.side, headers)

        self.response_headers = headers

        return status_code, self.response_headers

    @staticmethod
    def process_extensions(
        headers: Headers,
        available_extensions: Optional[Sequence[ClientExtensionFactory]],
    ) -> List[Extension]:
        """
        Handle the Sec-WebSocket-Extensions HTTP response header.

        Check that each extension is supported, as well as its parameters.

        Return the list of accepted extensions.

        Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the
        connection.

        :rfc:`6455` leaves the rules up to the specification of each
        extension.

        To provide this level of flexibility, for each extension accepted by
        the server, we check for a match with each extension available in the
        client configuration. If no match is found, an exception is raised.

        If several variants of the same extension are accepted by the server,
        it may be configured several times, which won't make sense in general.
        Extensions must implement their own requirements. For this purpose,
        the list of previously accepted extensions is provided.

        Other requirements, for example related to mandatory extensions or the
        order of extensions, may be implemented by overriding this method.

        """
        accepted_extensions: List[Extension] = []

        header_values = headers.get_all("Sec-WebSocket-Extensions")

        if header_values:

            if available_extensions is None:
                raise InvalidHandshake("no extensions supported")

            # Flatten the extensions from all header values, in order.
            parsed_header_values: List[ExtensionHeader] = [
                extension_header
                for header_value in header_values
                for extension_header in parse_extension(header_value)
            ]

            for name, response_params in parsed_header_values:

                for extension_factory in available_extensions:

                    # Skip non-matching extensions based on their name.
                    if extension_factory.name != name:
                        continue

                    # Skip non-matching extensions based on their params.
                    try:
                        extension = extension_factory.process_response_params(
                            response_params, accepted_extensions
                        )
                    except NegotiationError:
                        continue

                    # Add matching extension to the final list.
                    accepted_extensions.append(extension)

                    # Break out of the loop once we have a match.
                    break

                # If we didn't break from the loop, no extension in our list
                # matched what the server sent. Fail the connection.
                else:
                    raise NegotiationError(
                        f"Unsupported extension: "
                        f"name = {name}, params = {response_params}"
                    )

        return accepted_extensions

    @staticmethod
    def process_subprotocol(
        headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]]
    ) -> Optional[Subprotocol]:
        """
        Handle the Sec-WebSocket-Protocol HTTP response header.

        Check that it contains exactly one supported subprotocol.

        Return the selected subprotocol, or ``None`` if the server didn't
        send the header.

        """
        subprotocol: Optional[Subprotocol] = None

        header_values = headers.get_all("Sec-WebSocket-Protocol")

        if header_values:

            if available_subprotocols is None:
                raise InvalidHandshake("no subprotocols supported")

            # Flatten the subprotocols from all header values, in order.
            parsed_header_values: List[Subprotocol] = [
                parsed
                for header_value in header_values
                for parsed in parse_subprotocol(header_value)
            ]

            if len(parsed_header_values) > 1:
                subprotocols = ", ".join(parsed_header_values)
                raise InvalidHandshake(f"multiple subprotocols: {subprotocols}")

            subprotocol = parsed_header_values[0]

            if subprotocol not in available_subprotocols:
                raise NegotiationError(f"unsupported subprotocol: {subprotocol}")

        return subprotocol

    async def handshake(
        self,
        wsuri: WebSocketURI,
        origin: Optional[Origin] = None,
        available_extensions: Optional[Sequence[ClientExtensionFactory]] = None,
        available_subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLike] = None,
    ) -> None:
        """
        Perform the client side of the opening handshake.

        :param wsuri: target URI; provides the Host header, the credentials
            for the Authorization header, if any, and the request path
        :param origin: sets the Origin HTTP header
        :param available_extensions: list of supported extensions in the order
            in which they should be used
        :param available_subprotocols: list of supported subprotocols in order
            of decreasing preference
        :param extra_headers: sets additional HTTP request headers; it must be
            a :class:`~websockets.http.Headers` instance, a
            :class:`~collections.abc.Mapping`, or an iterable of ``(name,
            value)`` pairs
        :raises ~websockets.exceptions.RedirectHandshake: if the server
            responds with a 301, 302, 303, 307 or 308 redirect
        :raises ~websockets.exceptions.InvalidHandshake: if the handshake
            fails

        """
        request_headers = Headers()

        # Omit the port from the Host header when it's the scheme default.
        if wsuri.port == (443 if wsuri.secure else 80):  # pragma: no cover
            request_headers["Host"] = wsuri.host
        else:
            request_headers["Host"] = f"{wsuri.host}:{wsuri.port}"

        if wsuri.user_info:
            request_headers["Authorization"] = build_authorization_basic(
                *wsuri.user_info
            )

        if origin is not None:
            request_headers["Origin"] = origin

        # The key is needed later to validate the server's response.
        key = build_request(request_headers)

        if available_extensions is not None:
            extensions_header = build_extension(
                [
                    (extension_factory.name, extension_factory.get_request_params())
                    for extension_factory in available_extensions
                ]
            )
            request_headers["Sec-WebSocket-Extensions"] = extensions_header

        if available_subprotocols is not None:
            protocol_header = build_subprotocol(available_subprotocols)
            request_headers["Sec-WebSocket-Protocol"] = protocol_header

        if extra_headers is not None:
            if isinstance(extra_headers, Headers):
                extra_headers = extra_headers.raw_items()
            elif isinstance(extra_headers, collections.abc.Mapping):
                extra_headers = extra_headers.items()
            for name, value in extra_headers:
                request_headers[name] = value

        # Only add a User-Agent if extra_headers didn't provide one.
        request_headers.setdefault("User-Agent", USER_AGENT)

        self.write_http_request(wsuri.resource_name, request_headers)

        status_code, response_headers = await self.read_http_response()
        if status_code in (301, 302, 303, 307, 308):
            if "Location" not in response_headers:
                raise InvalidHeader("Location")
            raise RedirectHandshake(response_headers["Location"])
        elif status_code != 101:
            raise InvalidStatusCode(status_code)

        check_response(response_headers, key)

        self.extensions = self.process_extensions(
            response_headers, available_extensions
        )

        self.subprotocol = self.process_subprotocol(
            response_headers, available_subprotocols
        )

        self.connection_open()


class Connect:
    """
    Connect to the WebSocket server at the given ``uri``.

    Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which
    can then be used to send and receive messages.

    :func:`connect` can also be used as an asynchronous context manager. In
    that case, the connection is closed when exiting the context.

    :func:`connect` is a wrapper around the event loop's
    :meth:`~asyncio.loop.create_connection` method. Unknown keyword arguments
    are passed to :meth:`~asyncio.loop.create_connection`.

    For example, you can set the ``ssl`` keyword argument to a
    :class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to
    a ``wss://`` URI, if this argument isn't provided explicitly,
    :func:`ssl.create_default_context` is called to create a context.

    You can connect to a different host and port from those found in ``uri``
    by setting ``host`` and ``port`` keyword arguments. This only changes the
    destination of the TCP connection. The host name from ``uri`` is still
    used in the TLS handshake for secure connections and in the ``Host`` HTTP
    header.

    The ``create_protocol`` parameter allows customizing the
    :class:`~asyncio.Protocol` that manages the connection. It should be a
    callable or class accepting the same arguments as
    :class:`WebSocketClientProtocol` and returning an instance of
    :class:`WebSocketClientProtocol` or a subclass. It defaults to
    :class:`WebSocketClientProtocol`.

    The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
    ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is
    described in :class:`~websockets.protocol.WebSocketCommonProtocol`.

    :func:`connect` also accepts the following optional arguments:

    * ``compression`` is a shortcut to configure compression extensions;
      by default it enables the "permessage-deflate" extension; set it to
      ``None`` to disable compression
    * ``origin`` sets the Origin HTTP header
    * ``extensions`` is a list of supported extensions in order of
      decreasing preference
    * ``subprotocols`` is a list of supported subprotocols in order of
      decreasing preference
    * ``extra_headers`` sets additional HTTP request headers; it can be a
      :class:`~websockets.http.Headers` instance, a
      :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)``
      pairs

    HTTP redirects are followed, including relative ``Location`` values and
    upgrades from ``ws://`` to ``wss://``, for at most
    ``MAX_REDIRECTS_ALLOWED`` connection attempts in total.

    :raises ~websockets.uri.InvalidURI: if ``uri`` is invalid
    :raises ~websockets.exceptions.SecurityError: on too many redirects or
        on a redirect from ``wss://`` to ``ws://``
    :raises ~websockets.exceptions.InvalidHandshake: if the opening handshake
        fails

    """

    MAX_REDIRECTS_ALLOWED = 10

    def __init__(
        self,
        uri: str,
        *,
        path: Optional[str] = None,
        create_protocol: Optional[Type[WebSocketClientProtocol]] = None,
        ping_interval: float = 20,
        ping_timeout: float = 20,
        close_timeout: Optional[float] = None,
        max_size: int = 2 ** 20,
        max_queue: int = 2 ** 5,
        read_limit: int = 2 ** 16,
        write_limit: int = 2 ** 16,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        legacy_recv: bool = False,
        klass: Optional[Type[WebSocketClientProtocol]] = None,
        timeout: Optional[float] = None,
        compression: Optional[str] = "deflate",
        origin: Optional[Origin] = None,
        extensions: Optional[Sequence[ClientExtensionFactory]] = None,
        subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLike] = None,
        **kwargs: Any,
    ) -> None:
        # Backwards compatibility: close_timeout used to be called timeout.
        if timeout is None:
            timeout = 10
        else:
            warnings.warn("rename timeout to close_timeout", DeprecationWarning)
        # If both are specified, timeout is ignored.
        if close_timeout is None:
            close_timeout = timeout

        # Backwards compatibility: create_protocol used to be called klass.
        if klass is None:
            klass = WebSocketClientProtocol
        else:
            warnings.warn("rename klass to create_protocol", DeprecationWarning)
        # If both are specified, klass is ignored.
        if create_protocol is None:
            create_protocol = klass

        if loop is None:
            loop = asyncio.get_event_loop()

        wsuri = parse_uri(uri)
        if wsuri.secure:
            kwargs.setdefault("ssl", True)
        elif kwargs.get("ssl") is not None:
            raise ValueError(
                "connect() received a ssl argument for a ws:// URI, "
                "use a wss:// URI to enable TLS"
            )

        if compression == "deflate":
            if extensions is None:
                extensions = []
            # Add permessage-deflate unless the caller already configured it.
            if not any(
                extension_factory.name == ClientPerMessageDeflateFactory.name
                for extension_factory in extensions
            ):
                extensions = list(extensions) + [
                    ClientPerMessageDeflateFactory(client_max_window_bits=True)
                ]
        elif compression is not None:
            raise ValueError(f"unsupported compression: {compression}")

        factory = functools.partial(
            create_protocol,
            ping_interval=ping_interval,
            ping_timeout=ping_timeout,
            close_timeout=close_timeout,
            max_size=max_size,
            max_queue=max_queue,
            read_limit=read_limit,
            write_limit=write_limit,
            loop=loop,
            host=wsuri.host,
            port=wsuri.port,
            secure=wsuri.secure,
            legacy_recv=legacy_recv,
            origin=origin,
            extensions=extensions,
            subprotocols=subprotocols,
            extra_headers=extra_headers,
        )

        if path is None:
            host: Optional[str]
            port: Optional[int]
            if kwargs.get("sock") is None:
                host, port = wsuri.host, wsuri.port
            else:
                # If sock is given, host and port shouldn't be specified.
                host, port = None, None
                # Without a host, asyncio can't infer the TLS server name,
                # so use the host name from the URI.
                if kwargs.get("ssl"):
                    kwargs.setdefault("server_hostname", wsuri.host)
            # If host and port are given, override values from the URI.
            host = kwargs.pop("host", host)
            port = kwargs.pop("port", port)
            create_connection = functools.partial(
                loop.create_connection, factory, host, port, **kwargs
            )
        else:
            # A Unix socket has no host name for asyncio to infer the TLS
            # server name from, so use the host name from the URI.
            if kwargs.get("ssl"):
                kwargs.setdefault("server_hostname", wsuri.host)
            create_connection = functools.partial(
                loop.create_unix_connection, factory, path, **kwargs
            )

        # This is a coroutine function.
        self._create_connection = create_connection
        # Keep the original URI to resolve relative redirects against it.
        self._uri = uri
        self._wsuri = wsuri

    def handle_redirect(self, uri: str) -> None:
        """
        Update the state of this instance to connect to a new URI.

        :param uri: target of the redirect, from the ``Location`` header; a
            relative reference is resolved against the current URI
        :raises ~websockets.exceptions.SecurityError: on a redirect from
            ``wss://`` to ``ws://``

        """
        old_uri = self._uri
        old_wsuri = self._wsuri
        # Location may contain a relative reference (RFC 7231, 7.1.2).
        # Absolute URIs are used as they are.
        new_uri = urllib.parse.urljoin(old_uri, uri)
        new_wsuri = parse_uri(new_uri)

        # Forbid TLS downgrade.
        if old_wsuri.secure and not new_wsuri.secure:
            raise SecurityError("redirect from WSS to WS")

        same_origin = (
            old_wsuri.host == new_wsuri.host and old_wsuri.port == new_wsuri.port
        )
        # A ws:// -> wss:// redirect must enable TLS on the next connection.
        tls_upgrade = new_wsuri.secure and not old_wsuri.secure

        # Rewrite the host and port arguments for cross-origin redirects.
        # This preserves connection overrides with the host and port
        # arguments if the redirect points to the same host and port.
        if not same_origin or tls_upgrade:
            factory = self._create_connection.args[0]
            factory_keywords = dict(factory.keywords, secure=new_wsuri.secure)
            connection_args = self._create_connection.args[1:]
            connection_keywords = dict(self._create_connection.keywords)

            if not same_origin:
                # Replace the host and port passed to the protocol factory
                # and to create_connection.
                factory_keywords.update(host=new_wsuri.host, port=new_wsuri.port)
                connection_args = (new_wsuri.host, new_wsuri.port)

            if tls_upgrade:
                # __init__ rejects ssl for ws:// URIs, so it's unset or None.
                connection_keywords["ssl"] = True

            factory = functools.partial(
                factory.func, *factory.args, **factory_keywords
            )
            self._create_connection = functools.partial(
                self._create_connection.func,
                factory,
                *connection_args,
                **connection_keywords,
            )

        # Set the new WebSocket URI. This suffices for same-origin redirects
        # that keep the same scheme.
        self._uri = new_uri
        self._wsuri = new_wsuri

    # async with connect(...)

    async def __aenter__(self) -> WebSocketClientProtocol:
        """Open the connection and return the protocol."""
        return await self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Close the connection opened by :meth:`__aenter__`."""
        await self.ws_client.close()

    # await connect(...)

    def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]:
        # Create a suitable iterator by calling __await__ on a coroutine.
        return self.__await_impl__().__await__()

    async def __await_impl__(self) -> WebSocketClientProtocol:
        # Each iteration is one connection attempt; a redirect updates the
        # target and loops, any other outcome returns or raises.
        for _ in range(self.MAX_REDIRECTS_ALLOWED):
            transport, protocol = await self._create_connection()
            # https://github.com/python/typeshed/pull/2756
            transport = cast(asyncio.Transport, transport)
            protocol = cast(WebSocketClientProtocol, protocol)

            try:
                try:
                    await protocol.handshake(
                        self._wsuri,
                        origin=protocol.origin,
                        available_extensions=protocol.available_extensions,
                        available_subprotocols=protocol.available_subprotocols,
                        extra_headers=protocol.extra_headers,
                    )
                except Exception:
                    # Close this connection before propagating, including
                    # on redirects, which open a new connection.
                    protocol.fail_connection()
                    await protocol.wait_closed()
                    raise
                else:
                    self.ws_client = protocol
                    return protocol
            except RedirectHandshake as exc:
                self.handle_redirect(exc.uri)
        else:
            raise SecurityError("too many redirects")

    # yield from connect(...)

    __iter__ = __await__


connect = Connect


def unix_connect(path: str, uri: str = "ws://localhost/", **kwargs: Any) -> Connect:
    """
    Similar to :func:`connect`, but for connecting to a Unix socket.

    This function calls the event loop's
    :meth:`~asyncio.loop.create_unix_connection` method.

    It is only available on Unix.

    It's mainly useful for debugging servers listening on Unix sockets.

    :param path: file system path to the Unix socket
    :param uri: WebSocket URI

    """
    return connect(uri=uri, path=path, **kwargs)