1"""
2:mod:`websockets.client` defines the WebSocket client APIs.
3
4"""
5
import asyncio
import collections.abc
import functools
import logging
import urllib.parse
import warnings
from types import TracebackType
from typing import Any, Generator, List, Optional, Sequence, Tuple, Type, cast

from .exceptions import (
    InvalidHandshake,
    InvalidHeader,
    InvalidMessage,
    InvalidStatusCode,
    NegotiationError,
    RedirectHandshake,
    SecurityError,
)
from .extensions.base import ClientExtensionFactory, Extension
from .extensions.permessage_deflate import ClientPerMessageDeflateFactory
from .handshake import build_request, check_response
from .headers import (
    build_authorization_basic,
    build_extension,
    build_subprotocol,
    parse_extension,
    parse_subprotocol,
)
from .http import USER_AGENT, Headers, HeadersLike, read_response
from .protocol import WebSocketCommonProtocol
from .typing import ExtensionHeader, Origin, Subprotocol
from .uri import WebSocketURI, parse_uri
37
38
# Public names exported by ``from websockets.client import *``.
__all__ = ["connect", "unix_connect", "WebSocketClientProtocol"]

# Module-level logger; the client protocol logs the opening HTTP handshake
# (request and response lines and headers) at DEBUG level.
logger = logging.getLogger(__name__)
42
43
class WebSocketClientProtocol(WebSocketCommonProtocol):
    """
    :class:`~asyncio.Protocol` subclass implementing a WebSocket client.

    This class inherits most of its methods from
    :class:`~websockets.protocol.WebSocketCommonProtocol`.

    """

    is_client = True
    side = "client"

    def __init__(
        self,
        *,
        origin: Optional[Origin] = None,
        extensions: Optional[Sequence[ClientExtensionFactory]] = None,
        subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLike] = None,
        **kwargs: Any,
    ) -> None:
        # Keep the handshake configuration on the instance; it's read back
        # and passed to handshake() when the connection is established.
        self.origin = origin
        self.available_extensions = extensions
        self.available_subprotocols = subprotocols
        self.extra_headers = extra_headers
        # All remaining options are handled by WebSocketCommonProtocol.
        super().__init__(**kwargs)

    def write_http_request(self, path: str, headers: Headers) -> None:
        """
        Write the request line and headers of the opening HTTP request.

        The path and headers are also saved in ``self.path`` and
        ``self.request_headers``.

        """
        self.path = path
        self.request_headers = headers

        logger.debug("%s > GET %s HTTP/1.1", self.side, path)
        logger.debug("%s > %r", self.side, headers)

        # Since the path and headers only contain ASCII characters,
        # we can keep this simple.
        request = f"GET {path} HTTP/1.1\r\n"
        request += str(headers)

        self.transport.write(request.encode())

    async def read_http_response(self) -> Tuple[int, Headers]:
        """
        Read status line and headers from the HTTP response.

        If the response contains a body, it may be read from ``self.reader``
        after this coroutine returns.

        The headers are also saved in ``self.response_headers``.

        :raises ~websockets.exceptions.InvalidMessage: if the HTTP message is
            malformed or isn't an HTTP/1.1 response

        """
        try:
            status_code, reason, headers = await read_response(self.reader)
        except Exception as exc:
            raise InvalidMessage("did not receive a valid HTTP response") from exc

        logger.debug("%s < HTTP/1.1 %d %s", self.side, status_code, reason)
        logger.debug("%s < %r", self.side, headers)

        self.response_headers = headers

        return status_code, self.response_headers

    @staticmethod
    def process_extensions(
        headers: Headers,
        available_extensions: Optional[Sequence[ClientExtensionFactory]],
    ) -> List[Extension]:
        """
        Handle the Sec-WebSocket-Extensions HTTP response header.

        Check that each extension is supported, as well as its parameters.

        Return the list of accepted extensions.

        Raise :exc:`~websockets.exceptions.InvalidHandshake` to abort the
        connection.

        :rfc:`6455` leaves the rules up to the specification of each
        extension.

        To provide this level of flexibility, for each extension accepted by
        the server, we check for a match with each extension available in the
        client configuration. If no match is found, an exception is raised.

        If several variants of the same extension are accepted by the server,
        it may be configured several times, which won't make sense in general.
        Extensions must implement their own requirements. For this purpose,
        the list of previously accepted extensions is provided.

        Other requirements, for example related to mandatory extensions or the
        order of extensions, may be implemented by overriding this method.

        :param headers: HTTP response headers
        :param available_extensions: extension factories offered by the
            client, or ``None`` if no extensions were offered
        :raises ~websockets.exceptions.InvalidHandshake: if the server
            selected extensions although none were offered
        :raises ~websockets.exceptions.NegotiationError: if an extension
            selected by the server matches none of the available factories

        """
        accepted_extensions: List[Extension] = []

        header_values = headers.get_all("Sec-WebSocket-Extensions")

        if header_values:

            if available_extensions is None:
                raise InvalidHandshake("no extensions supported")

            # Flatten the extensions listed across all header occurrences.
            parsed_header_values: List[ExtensionHeader] = [
                parsed_value
                for header_value in header_values
                for parsed_value in parse_extension(header_value)
            ]

            for name, response_params in parsed_header_values:

                for extension_factory in available_extensions:

                    # Skip non-matching extensions based on their name.
                    if extension_factory.name != name:
                        continue

                    # Skip non-matching extensions based on their params.
                    try:
                        extension = extension_factory.process_response_params(
                            response_params, accepted_extensions
                        )
                    except NegotiationError:
                        continue

                    # Add matching extension to the final list.
                    accepted_extensions.append(extension)

                    # Break out of the loop once we have a match.
                    break

                # If we didn't break from the loop, no extension in our list
                # matched what the server sent. Fail the connection.
                else:
                    raise NegotiationError(
                        f"Unsupported extension: "
                        f"name = {name}, params = {response_params}"
                    )

        return accepted_extensions

    @staticmethod
    def process_subprotocol(
        headers: Headers, available_subprotocols: Optional[Sequence[Subprotocol]]
    ) -> Optional[Subprotocol]:
        """
        Handle the Sec-WebSocket-Protocol HTTP response header.

        Check that it contains exactly one supported subprotocol.

        Return the selected subprotocol, or ``None`` if the header is absent.

        :param headers: HTTP response headers
        :param available_subprotocols: subprotocols offered by the client, or
            ``None`` if no subprotocols were offered
        :raises ~websockets.exceptions.InvalidHandshake: if the server
            selected a subprotocol although none were offered, or selected
            several subprotocols
        :raises ~websockets.exceptions.NegotiationError: if the selected
            subprotocol wasn't offered

        """
        subprotocol: Optional[Subprotocol] = None

        header_values = headers.get_all("Sec-WebSocket-Protocol")

        if header_values:

            if available_subprotocols is None:
                raise InvalidHandshake("no subprotocols supported")

            # Flatten the subprotocols listed across all header occurrences.
            parsed_header_values: List[Subprotocol] = [
                parsed_value
                for header_value in header_values
                for parsed_value in parse_subprotocol(header_value)
            ]

            if len(parsed_header_values) > 1:
                subprotocols = ", ".join(parsed_header_values)
                raise InvalidHandshake(f"multiple subprotocols: {subprotocols}")

            subprotocol = parsed_header_values[0]

            if subprotocol not in available_subprotocols:
                raise NegotiationError(f"unsupported subprotocol: {subprotocol}")

        return subprotocol

    async def handshake(
        self,
        wsuri: WebSocketURI,
        origin: Optional[Origin] = None,
        available_extensions: Optional[Sequence[ClientExtensionFactory]] = None,
        available_subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLike] = None,
    ) -> None:
        """
        Perform the client side of the opening handshake.

        :param origin: sets the Origin HTTP header
        :param available_extensions: list of supported extensions in the order
            in which they should be used
        :param available_subprotocols: list of supported subprotocols in order
            of decreasing preference
        :param extra_headers: sets additional HTTP request headers; it must be
            a :class:`~websockets.http.Headers` instance, a
            :class:`~collections.abc.Mapping`, or an iterable of ``(name,
            value)`` pairs
        :raises ~websockets.exceptions.RedirectHandshake: if the server
            answers with a redirect status and a Location header
        :raises ~websockets.exceptions.InvalidHandshake: if the handshake
            fails

        """
        request_headers = Headers()

        # An IPv6 literal must be enclosed in brackets in the Host header
        # (RFC 3986, section 3.2.2). wsuri.host presumably holds the bare
        # address (URI parsers strip the brackets); registered names and
        # IPv4 addresses never contain a colon, so a colon means IPv6.
        host = wsuri.host
        if ":" in host and not host.startswith("["):
            host = f"[{host}]"

        # Omit the port when it's the default for the scheme.
        if wsuri.port == (443 if wsuri.secure else 80):
            request_headers["Host"] = host
        else:
            request_headers["Host"] = f"{host}:{wsuri.port}"

        if wsuri.user_info:
            request_headers["Authorization"] = build_authorization_basic(
                *wsuri.user_info
            )

        if origin is not None:
            request_headers["Origin"] = origin

        # build_request adds the handshake headers and returns the key needed
        # to validate the server's response.
        key = build_request(request_headers)

        if available_extensions is not None:
            extensions_header = build_extension(
                [
                    (extension_factory.name, extension_factory.get_request_params())
                    for extension_factory in available_extensions
                ]
            )
            request_headers["Sec-WebSocket-Extensions"] = extensions_header

        if available_subprotocols is not None:
            protocol_header = build_subprotocol(available_subprotocols)
            request_headers["Sec-WebSocket-Protocol"] = protocol_header

        if extra_headers is not None:
            # Normalize every supported HeadersLike form to (name, value) pairs.
            if isinstance(extra_headers, Headers):
                extra_headers = extra_headers.raw_items()
            elif isinstance(extra_headers, collections.abc.Mapping):
                extra_headers = extra_headers.items()
            for name, value in extra_headers:
                request_headers[name] = value

        # Only add the default User-Agent if extra_headers didn't provide one.
        request_headers.setdefault("User-Agent", USER_AGENT)

        self.write_http_request(wsuri.resource_name, request_headers)

        status_code, response_headers = await self.read_http_response()
        if status_code in (301, 302, 303, 307, 308):
            if "Location" not in response_headers:
                raise InvalidHeader("Location")
            raise RedirectHandshake(response_headers["Location"])
        elif status_code != 101:
            raise InvalidStatusCode(status_code)

        check_response(response_headers, key)

        self.extensions = self.process_extensions(
            response_headers, available_extensions
        )

        self.subprotocol = self.process_subprotocol(
            response_headers, available_subprotocols
        )

        self.connection_open()
309
310
class Connect:
    """
    Connect to the WebSocket server at the given ``uri``.

    Awaiting :func:`connect` yields a :class:`WebSocketClientProtocol` which
    can then be used to send and receive messages.

    :func:`connect` can also be used as an asynchronous context manager. In
    that case, the connection is closed when exiting the context.

    :func:`connect` is a wrapper around the event loop's
    :meth:`~asyncio.loop.create_connection` method. Unknown keyword arguments
    are passed to :meth:`~asyncio.loop.create_connection`.

    For example, you can set the ``ssl`` keyword argument to a
    :class:`~ssl.SSLContext` to enforce some TLS settings. When connecting to
    a ``wss://`` URI, if this argument isn't provided explicitly,
    :func:`ssl.create_default_context` is called to create a context.

    You can connect to a different host and port from those found in ``uri``
    by setting ``host`` and ``port`` keyword arguments. This only changes the
    destination of the TCP connection. The host name from ``uri`` is still
    used in the TLS handshake for secure connections and in the ``Host`` HTTP
    header.

    The ``create_protocol`` parameter allows customizing the
    :class:`~asyncio.Protocol` that manages the connection. It should be a
    callable or class accepting the same arguments as
    :class:`WebSocketClientProtocol` and returning an instance of
    :class:`WebSocketClientProtocol` or a subclass. It defaults to
    :class:`WebSocketClientProtocol`.

    The behavior of ``ping_interval``, ``ping_timeout``, ``close_timeout``,
    ``max_size``, ``max_queue``, ``read_limit``, and ``write_limit`` is
    described in :class:`~websockets.protocol.WebSocketCommonProtocol`.

    :func:`connect` also accepts the following optional arguments:

    * ``compression`` is a shortcut to configure compression extensions;
      by default it enables the "permessage-deflate" extension; set it to
      ``None`` to disable compression
    * ``origin`` sets the Origin HTTP header
    * ``extensions`` is a list of supported extensions in order of
      decreasing preference
    * ``subprotocols`` is a list of supported subprotocols in order of
      decreasing preference
    * ``extra_headers`` sets additional HTTP request headers; it can be a
      :class:`~websockets.http.Headers` instance, a
      :class:`~collections.abc.Mapping`, or an iterable of ``(name, value)``
      pairs

    HTTP redirects (301, 302, 303, 307, 308) are followed. Relative
    ``Location`` values are resolved against the current URI, redirects from
    ``wss://`` to ``ws://`` are rejected, and redirects from ``ws://`` to
    ``wss://`` enable TLS. At most ``MAX_REDIRECTS_ALLOWED`` connection
    attempts are made.

    :raises ~websockets.uri.InvalidURI: if ``uri`` is invalid
    :raises ~websockets.exceptions.InvalidHandshake: if the opening handshake
        fails
    :raises ~websockets.exceptions.SecurityError: on a TLS downgrade or too
        many redirects

    """

    MAX_REDIRECTS_ALLOWED = 10

    def __init__(
        self,
        uri: str,
        *,
        path: Optional[str] = None,
        create_protocol: Optional[Type[WebSocketClientProtocol]] = None,
        ping_interval: float = 20,
        ping_timeout: float = 20,
        close_timeout: Optional[float] = None,
        max_size: int = 2 ** 20,
        max_queue: int = 2 ** 5,
        read_limit: int = 2 ** 16,
        write_limit: int = 2 ** 16,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        legacy_recv: bool = False,
        klass: Optional[Type[WebSocketClientProtocol]] = None,
        timeout: Optional[float] = None,
        compression: Optional[str] = "deflate",
        origin: Optional[Origin] = None,
        extensions: Optional[Sequence[ClientExtensionFactory]] = None,
        subprotocols: Optional[Sequence[Subprotocol]] = None,
        extra_headers: Optional[HeadersLike] = None,
        **kwargs: Any,
    ) -> None:
        # Backwards compatibility: close_timeout used to be called timeout.
        if timeout is None:
            timeout = 10
        else:
            # stacklevel=2 attributes the warning to the caller of connect().
            warnings.warn(
                "rename timeout to close_timeout", DeprecationWarning, stacklevel=2
            )
        # If both are specified, timeout is ignored.
        if close_timeout is None:
            close_timeout = timeout

        # Backwards compatibility: create_protocol used to be called klass.
        if klass is None:
            klass = WebSocketClientProtocol
        else:
            warnings.warn(
                "rename klass to create_protocol", DeprecationWarning, stacklevel=2
            )
        # If both are specified, klass is ignored.
        if create_protocol is None:
            create_protocol = klass

        if loop is None:
            loop = asyncio.get_event_loop()

        wsuri = parse_uri(uri)
        if wsuri.secure:
            kwargs.setdefault("ssl", True)
        elif kwargs.get("ssl") is not None:
            raise ValueError(
                "connect() received a ssl argument for a ws:// URI, "
                "use a wss:// URI to enable TLS"
            )

        if compression == "deflate":
            if extensions is None:
                extensions = []
            # Only add permessage-deflate if the caller didn't configure it.
            if not any(
                extension_factory.name == ClientPerMessageDeflateFactory.name
                for extension_factory in extensions
            ):
                extensions = list(extensions) + [
                    ClientPerMessageDeflateFactory(client_max_window_bits=True)
                ]
        elif compression is not None:
            raise ValueError(f"unsupported compression: {compression}")

        factory = functools.partial(
            create_protocol,
            ping_interval=ping_interval,
            ping_timeout=ping_timeout,
            close_timeout=close_timeout,
            max_size=max_size,
            max_queue=max_queue,
            read_limit=read_limit,
            write_limit=write_limit,
            loop=loop,
            host=wsuri.host,
            port=wsuri.port,
            secure=wsuri.secure,
            legacy_recv=legacy_recv,
            origin=origin,
            extensions=extensions,
            subprotocols=subprotocols,
            extra_headers=extra_headers,
        )

        if path is None:
            host: Optional[str]
            port: Optional[int]
            if kwargs.get("sock") is None:
                host, port = wsuri.host, wsuri.port
            else:
                # If sock is given, host and port shouldn't be specified.
                host, port = None, None
            # If host and port are given, override values from the URI.
            host = kwargs.pop("host", host)
            port = kwargs.pop("port", port)
            create_connection = functools.partial(
                loop.create_connection, factory, host, port, **kwargs
            )
        else:
            create_connection = functools.partial(
                loop.create_unix_connection, factory, path, **kwargs
            )

        # Calling this returns a coroutine that opens the connection.
        self._create_connection = create_connection
        # Keep the original URI string so relative redirects can be resolved.
        self._uri = uri
        self._wsuri = wsuri

    def handle_redirect(self, uri: str) -> None:
        """
        Update the state of this instance to connect to a redirect target.

        :param uri: value of the ``Location`` header; a relative reference is
            resolved against the URI of the previous attempt
        :raises ~websockets.exceptions.SecurityError: on a redirect from
            ``wss://`` to ``ws://``

        """
        old_uri = self._uri
        old_wsuri = self._wsuri
        # Location may be a relative reference (RFC 7231, section 7.1.2).
        new_uri = urllib.parse.urljoin(old_uri, uri)
        new_wsuri = parse_uri(new_uri)

        # Forbid TLS downgrade.
        if old_wsuri.secure and not new_wsuri.secure:
            raise SecurityError("redirect from WSS to WS")

        # The scheme is part of the origin: switching from ws:// to wss://
        # must reconfigure the connection even when host and port match.
        same_origin = (
            old_wsuri.secure == new_wsuri.secure
            and old_wsuri.host == new_wsuri.host
            and old_wsuri.port == new_wsuri.port
        )

        # Rewrite the connection arguments for cross-origin redirects.
        # This preserves connection overrides with the host and port
        # arguments if the redirect points to the same origin.
        if not same_origin:
            old_connection = self._create_connection
            old_factory = old_connection.args[0]

            factory_keywords = dict(
                old_factory.keywords, host=new_wsuri.host, port=new_wsuri.port
            )
            connection_keywords = dict(old_connection.keywords)

            # Support TLS upgrade (ws:// -> wss://), mirroring __init__.
            if new_wsuri.secure and not old_wsuri.secure:
                factory_keywords["secure"] = True
                connection_keywords.setdefault("ssl", True)

            # Replace the arguments passed to the protocol factory.
            factory = functools.partial(
                old_factory.func, *old_factory.args, **factory_keywords
            )
            # Replace the factory, host, and port passed to create_connection.
            self._create_connection = functools.partial(
                old_connection.func,
                factory,
                new_wsuri.host,
                new_wsuri.port,
                **connection_keywords,
            )

        # Set the new WebSocket URI. This suffices for same-origin redirects.
        self._uri = new_uri
        self._wsuri = new_wsuri

    # async with connect(...)

    async def __aenter__(self) -> WebSocketClientProtocol:
        return await self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        await self.ws_client.close()

    # await connect(...)

    def __await__(self) -> Generator[Any, None, WebSocketClientProtocol]:
        # Create a suitable iterator by calling __await__ on a coroutine.
        return self.__await_impl__().__await__()

    async def __await_impl__(self) -> WebSocketClientProtocol:
        for _ in range(self.MAX_REDIRECTS_ALLOWED):
            transport, protocol = await self._create_connection()
            # https://github.com/python/typeshed/pull/2756
            transport = cast(asyncio.Transport, transport)
            protocol = cast(WebSocketClientProtocol, protocol)

            try:
                try:
                    await protocol.handshake(
                        self._wsuri,
                        origin=protocol.origin,
                        available_extensions=protocol.available_extensions,
                        available_subprotocols=protocol.available_subprotocols,
                        extra_headers=protocol.extra_headers,
                    )
                except Exception:
                    # Close this connection before propagating any failure,
                    # including RedirectHandshake.
                    protocol.fail_connection()
                    await protocol.wait_closed()
                    raise
                else:
                    self.ws_client = protocol
                    return protocol
            except RedirectHandshake as exc:
                self.handle_redirect(exc.uri)
        else:
            raise SecurityError("too many redirects")

    # yield from connect(...)

    __iter__ = __await__
564
565
# Public, function-style name for the Connect class; it supports both
# ``await connect(...)`` and ``async with connect(...)``.
connect = Connect
567
568
def unix_connect(path: str, uri: str = "ws://localhost/", **kwargs: Any) -> Connect:
    """
    Open a WebSocket connection over a Unix socket instead of TCP.

    This behaves like :func:`connect`, except that the connection is opened
    with the event loop's :meth:`~asyncio.loop.create_unix_connection`
    method. It is therefore only available on Unix.

    It comes in handy mostly for debugging servers that listen on Unix
    sockets.

    :param path: file system path to the Unix socket
    :param uri: WebSocket URI; it's used for the handshake (e.g. the ``Host``
        header), not to choose the connection target
    :param kwargs: any other keyword argument accepted by :func:`connect`

    """
    # ``path`` and ``uri`` are named parameters, so ``kwargs`` can't already
    # contain them; routing ``path`` through ``kwargs`` is equivalent.
    kwargs["path"] = path
    return connect(uri, **kwargs)
585