1"""
2wsproto/handshake
3~~~~~~~~~~~~~~~~~~
4
5An implementation of WebSocket handshakes.
6"""
7from collections import deque
8from typing import (
9    cast,
10    Deque,
11    Dict,
12    Generator,
13    Iterable,
14    List,
15    Optional,
16    Sequence,
17    Union,
18)
19
20import h11
21
22from .connection import Connection, ConnectionState, ConnectionType
23from .events import AcceptConnection, Event, RejectConnection, RejectData, Request
24from .extensions import Extension
25from .typing import Headers
26from .utilities import (
27    generate_accept_token,
28    generate_nonce,
29    LocalProtocolError,
30    normed_header_dict,
31    RemoteProtocolError,
32    split_comma_header,
33)
34
# The Sec-WebSocket-Version value this module sends as a client and requires
# from clients as a server (RFC 6455, Section 4.2.1, item 6 - Reading the
# Client's Opening Handshake).
WEBSOCKET_VERSION: bytes = b"13"
37
38
class H11Handshake:
    """A Handshake implementation for HTTP/1.1 connections.

    Wraps an ``h11`` connection in either the client or the server role and
    translates between wsproto handshake events and HTTP/1.1 bytes. Once the
    handshake succeeds the resulting :class:`Connection` is available via
    :attr:`connection`.
    """

    def __init__(self, connection_type: ConnectionType) -> None:
        self.client = connection_type is ConnectionType.CLIENT
        self._state = ConnectionState.CONNECTING

        if self.client:
            self._h11_connection = h11.Connection(h11.CLIENT)
        else:
            self._h11_connection = h11.Connection(h11.SERVER)

        # Populated once the handshake succeeds (_accept on the server,
        # _establish_client_connection on the client).
        self._connection: Optional[Connection] = None
        # Events produced by receive_data(), drained by events().
        self._events: Deque[Event] = deque()
        # Client: the Request that was sent. Server: the Request received.
        self._initiating_request: Optional[Request] = None
        # Client only: the Sec-WebSocket-Key sent with the Request.
        self._nonce: Optional[bytes] = None

    @property
    def state(self) -> ConnectionState:
        """Return the current state of the handshake."""
        return self._state

    @property
    def connection(self) -> Optional[Connection]:
        """Return the established connection.

        This returns the :class:`Connection` created when the handshake
        completed, or ``None`` if the connection has not (yet) been
        established.

        :rtype: Optional[Connection]
        """
        return self._connection

    def initiate_upgrade_connection(self, headers: Headers, path: str) -> None:
        """Initiate an upgrade connection.

        This should be used if the request has already been received and
        parsed. The request is re-serialized with a throwaway h11 client
        connection and fed through :meth:`receive_data`.

        :param list headers: HTTP headers represented as a list of 2-tuples.
        :param str path: A URL path.
        :raises LocalProtocolError: If acting as the client.
        """
        if self.client:
            raise LocalProtocolError(
                "Cannot initiate an upgrade connection when acting as the client"
            )
        upgrade_request = h11.Request(method=b"GET", target=path, headers=headers)
        h11_client = h11.Connection(h11.CLIENT)
        self.receive_data(h11_client.send(upgrade_request))

    def send(self, event: Event) -> bytes:
        """Send an event to the remote.

        This will return the bytes to send based on the event or raise
        a LocalProtocolError if the event is not valid given the
        state.

        :returns: Data to send to the WebSocket peer.
        :rtype: bytes
        """
        data = b""
        if isinstance(event, Request):
            data += self._initiate_connection(event)
        elif isinstance(event, AcceptConnection):
            data += self._accept(event)
        elif isinstance(event, RejectConnection):
            data += self._reject(event)
        elif isinstance(event, RejectData):
            data += self._send_reject_data(event)
        else:
            raise LocalProtocolError(
                f"Event {event} cannot be sent during the handshake"
            )
        return data

    def receive_data(self, data: Optional[bytes]) -> None:
        """Receive data from the remote.

        A list of events that the remote peer triggered by sending
        this data can be retrieved with :meth:`events`.

        :param bytes data: Data received from the WebSocket peer; passed
            unchanged to h11's ``receive_data``.
        :raises RemoteProtocolError: If the peer sent a malformed HTTP
            message or an invalid handshake.
        """
        self._h11_connection.receive_data(data)
        while True:
            try:
                event = self._h11_connection.next_event()
            except h11.RemoteProtocolError as exc:
                # Chain the h11 error so the underlying parse failure is
                # preserved in the traceback.
                raise RemoteProtocolError(
                    "Bad HTTP message", event_hint=RejectConnection()
                ) from exc
            if (
                isinstance(event, h11.ConnectionClosed)
                or event is h11.NEED_DATA
                or event is h11.PAUSED
            ):
                break

            if self.client:
                if isinstance(event, h11.InformationalResponse):
                    if event.status_code == 101:
                        self._events.append(self._establish_client_connection(event))
                    else:
                        self._events.append(
                            RejectConnection(
                                headers=event.headers,
                                status_code=event.status_code,
                                has_body=False,
                            )
                        )
                        self._state = ConnectionState.CLOSED
                elif isinstance(event, h11.Response):
                    self._state = ConnectionState.REJECTING
                    self._events.append(
                        RejectConnection(
                            headers=event.headers,
                            status_code=event.status_code,
                            has_body=True,
                        )
                    )
                elif isinstance(event, h11.Data):
                    self._events.append(
                        RejectData(data=event.data, body_finished=False)
                    )
                elif isinstance(event, h11.EndOfMessage):
                    self._events.append(RejectData(data=b"", body_finished=True))
                    self._state = ConnectionState.CLOSED
            else:
                if isinstance(event, h11.Request):
                    self._events.append(self._process_connection_request(event))

    def events(self) -> Generator[Event, None, None]:
        """Return a generator that provides any events that have been generated
        by protocol activity.

        :returns: a generator that yields H11 events.
        """
        while self._events:
            yield self._events.popleft()

    ############ Server mode methods

    def _process_connection_request(  # noqa: MC0001
        self, event: h11.Request
    ) -> Request:
        """Validate a client's opening handshake and build a Request event.

        The Host, Sec-WebSocket-Extensions and Sec-WebSocket-Protocol headers
        are lifted into dedicated fields. All other headers, including
        Sec-WebSocket-Key (which _accept reads back later), are kept in
        ``extra_headers``.

        :raises RemoteProtocolError: If the method is not GET or a required
            handshake header is missing or has the wrong value.
        """
        if event.method != b"GET":
            raise RemoteProtocolError(
                "Request method must be GET", event_hint=RejectConnection()
            )
        connection_tokens = None
        extensions: List[str] = []
        host = None
        key = None
        subprotocols: List[str] = []
        upgrade = b""
        version = None
        headers: Headers = []
        for name, value in event.headers:
            name = name.lower()
            if name == b"connection":
                connection_tokens = split_comma_header(value)
            elif name == b"host":
                host = value.decode("ascii")
                continue  # Skip appending to headers
            elif name == b"sec-websocket-extensions":
                extensions = split_comma_header(value)
                continue  # Skip appending to headers
            elif name == b"sec-websocket-key":
                key = value
            elif name == b"sec-websocket-protocol":
                subprotocols = split_comma_header(value)
                continue  # Skip appending to headers
            elif name == b"sec-websocket-version":
                version = value
            elif name == b"upgrade":
                upgrade = value
            headers.append((name, value))
        if connection_tokens is None or not any(
            token.lower() == "upgrade" for token in connection_tokens
        ):
            raise RemoteProtocolError(
                "Missing header, 'Connection: Upgrade'", event_hint=RejectConnection()
            )
        if version != WEBSOCKET_VERSION:
            # 426 Upgrade Required when a different version was requested,
            # 400 when the header is absent entirely.
            raise RemoteProtocolError(
                "Missing header, 'Sec-WebSocket-Version'",
                event_hint=RejectConnection(
                    headers=[(b"Sec-WebSocket-Version", WEBSOCKET_VERSION)],
                    status_code=426 if version else 400,
                ),
            )
        if key is None:
            raise RemoteProtocolError(
                "Missing header, 'Sec-WebSocket-Key'", event_hint=RejectConnection()
            )
        if upgrade.lower() != b"websocket":
            raise RemoteProtocolError(
                "Missing header, 'Upgrade: WebSocket'", event_hint=RejectConnection()
            )
        if host is None:
            raise RemoteProtocolError(
                "Missing header, 'Host'", event_hint=RejectConnection()
            )

        self._initiating_request = Request(
            extensions=extensions,
            extra_headers=headers,
            host=host,
            subprotocols=subprotocols,
            target=event.target.decode("ascii"),
        )
        return self._initiating_request

    def _accept(self, event: AcceptConnection) -> bytes:
        """Serialize a 101 Switching Protocols response and open the connection.

        :raises LocalProtocolError: If acting as the client, if the handshake
            is no longer CONNECTING, if no upgrade request has been received,
            or if ``event.subprotocol`` was not requested by the client.
        """
        # Guard like _reject/_send_reject_data so misuse surfaces as a
        # wsproto LocalProtocolError rather than a KeyError, AssertionError
        # or a raw h11 error.
        if self.client:
            raise LocalProtocolError(
                "Cannot accept a connection when acting as the client"
            )
        if self.state != ConnectionState.CONNECTING:
            raise LocalProtocolError(
                f"Connection cannot be accepted in state {self.state}"
            )
        request = self._initiating_request
        if request is None:
            raise LocalProtocolError(
                "Cannot accept a connection before receiving a request"
            )
        request_headers = normed_header_dict(request.extra_headers)

        nonce = request_headers[b"sec-websocket-key"]
        accept_token = generate_accept_token(nonce)

        headers = [
            (b"Upgrade", b"WebSocket"),
            (b"Connection", b"Upgrade"),
            (b"Sec-WebSocket-Accept", accept_token),
        ]

        if event.subprotocol is not None:
            if event.subprotocol not in request.subprotocols:
                raise LocalProtocolError(f"unexpected subprotocol {event.subprotocol}")
            headers.append(
                (b"Sec-WebSocket-Protocol", event.subprotocol.encode("ascii"))
            )

        if event.extensions:
            accepts = server_extensions_handshake(
                cast(Sequence[str], request.extensions),
                event.extensions,
            )
            if accepts:
                headers.append((b"Sec-WebSocket-Extensions", accepts))

        response = h11.InformationalResponse(
            status_code=101, headers=headers + event.extra_headers
        )
        # Serialize first so a failing h11 send leaves the handshake in
        # CONNECTING rather than claiming OPEN.
        data = self._h11_connection.send(response)
        self._connection = Connection(ConnectionType.SERVER, event.extensions)
        self._state = ConnectionState.OPEN
        return data

    def _reject(self, event: RejectConnection) -> bytes:
        """Serialize a rejection response, plus an empty body if it has none.

        :raises LocalProtocolError: If the handshake is not CONNECTING.
        """
        if self.state != ConnectionState.CONNECTING:
            raise LocalProtocolError(
                f"Connection cannot be rejected in state {self.state}"
            )

        # Copy so that adding Content-Length does not mutate the caller's
        # event (e.g. an event_hint that may be inspected or sent again).
        headers = list(event.headers)
        if not event.has_body:
            headers.append((b"content-length", b"0"))
        response = h11.Response(status_code=event.status_code, headers=headers)
        data = self._h11_connection.send(response)
        self._state = ConnectionState.REJECTING
        if not event.has_body:
            data += self._h11_connection.send(h11.EndOfMessage())
            self._state = ConnectionState.CLOSED
        return data

    def _send_reject_data(self, event: RejectData) -> bytes:
        """Serialize a chunk of a rejection body; finish it if requested.

        :raises LocalProtocolError: If the handshake is not REJECTING.
        """
        if self.state != ConnectionState.REJECTING:
            raise LocalProtocolError(
                f"Cannot send rejection data in state {self.state}"
            )

        data = self._h11_connection.send(h11.Data(data=event.data))
        if event.body_finished:
            data += self._h11_connection.send(h11.EndOfMessage())
            self._state = ConnectionState.CLOSED
        return data

    ############ Client mode methods

    def _initiate_connection(self, request: Request) -> bytes:
        """Serialize the client's opening handshake request.

        :raises LocalProtocolError: If acting as the server or if a request
            has already been sent.
        """
        if not self.client:
            raise LocalProtocolError(
                "Cannot initiate a connection when acting as the server"
            )
        if self._initiating_request is not None:
            # Re-initiating would replace the nonce that the server's 101
            # response is validated against.
            raise LocalProtocolError("Connection has already been initiated")

        nonce = generate_nonce()

        headers = [
            (b"Host", request.host.encode("ascii")),
            (b"Upgrade", b"WebSocket"),
            (b"Connection", b"Upgrade"),
            (b"Sec-WebSocket-Key", nonce),
            (b"Sec-WebSocket-Version", WEBSOCKET_VERSION),
        ]

        if request.subprotocols:
            headers.append(
                (
                    b"Sec-WebSocket-Protocol",
                    (", ".join(request.subprotocols)).encode("ascii"),
                )
            )

        if request.extensions:
            offers: Dict[str, Union[str, bool]] = {}
            for e in request.extensions:
                assert isinstance(e, Extension)
                offers[e.name] = e.offer()
            extensions = []
            for name, params in offers.items():
                bname = name.encode("ascii")
                if isinstance(params, bool):
                    if params:
                        extensions.append(bname)
                elif params:
                    extensions.append(b"%s; %s" % (bname, params.encode("ascii")))
                else:
                    # Parameterless offer: send just the name (as the server
                    # side does) instead of a malformed "name; ".
                    extensions.append(bname)
            if extensions:
                headers.append((b"Sec-WebSocket-Extensions", b", ".join(extensions)))

        upgrade = h11.Request(
            method=b"GET",
            target=request.target.encode("ascii"),
            headers=headers + request.extra_headers,
        )
        data = self._h11_connection.send(upgrade)
        # Commit only once h11 has accepted the request.
        self._initiating_request = request
        self._nonce = nonce
        return data

    def _establish_client_connection(
        self, event: h11.InformationalResponse
    ) -> AcceptConnection:  # noqa: MC0001
        """Validate the server's 101 response and create the client Connection.

        :raises RemoteProtocolError: If the Connection/Upgrade headers, the
            accept token, the subprotocol or the extensions are invalid.
        """
        # _establish_client_connection is always called after _initiate_connection.
        assert self._initiating_request is not None
        assert self._nonce is not None

        accept = None
        connection_tokens = None
        accepts: List[str] = []
        subprotocol = None
        upgrade = b""
        headers: Headers = []
        for name, value in event.headers:
            name = name.lower()
            if name == b"connection":
                connection_tokens = split_comma_header(value)
                continue  # Skip appending to headers
            elif name == b"sec-websocket-extensions":
                accepts = split_comma_header(value)
                continue  # Skip appending to headers
            elif name == b"sec-websocket-accept":
                accept = value
                continue  # Skip appending to headers
            elif name == b"sec-websocket-protocol":
                subprotocol = value
                continue  # Skip appending to headers
            elif name == b"upgrade":
                upgrade = value
                continue  # Skip appending to headers
            headers.append((name, value))

        if connection_tokens is None or not any(
            token.lower() == "upgrade" for token in connection_tokens
        ):
            raise RemoteProtocolError(
                "Missing header, 'Connection: Upgrade'", event_hint=RejectConnection()
            )
        if upgrade.lower() != b"websocket":
            raise RemoteProtocolError(
                "Missing header, 'Upgrade: WebSocket'", event_hint=RejectConnection()
            )
        accept_token = generate_accept_token(self._nonce)
        if accept != accept_token:
            raise RemoteProtocolError("Bad accept token", event_hint=RejectConnection())
        if subprotocol is not None:
            subprotocol = subprotocol.decode("ascii")
            if subprotocol not in self._initiating_request.subprotocols:
                raise RemoteProtocolError(
                    f"unrecognized subprotocol {subprotocol}",
                    event_hint=RejectConnection(),
                )
        extensions = client_extensions_handshake(
            accepts, cast(Sequence[Extension], self._initiating_request.extensions)
        )

        self._connection = Connection(
            ConnectionType.CLIENT if self.client else ConnectionType.SERVER,
            extensions,
            self._h11_connection.trailing_data[0],
        )
        self._state = ConnectionState.OPEN
        return AcceptConnection(
            extensions=extensions, extra_headers=headers, subprotocol=subprotocol
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(client={self.client}, state={self.state})"
435
436
def server_extensions_handshake(
    requested: Iterable[str], supported: List[Extension]
) -> Optional[bytes]:
    """Agree on the extensions to use returning an appropriate header value.

    Each offer in *requested* is matched by name against *supported* and
    handed to that extension's ``accept``. Offers are considered in the order
    the client listed them, which is its order of preference. Once an
    extension has been agreed, later (fallback) offers for the same extension
    are ignored, and their ``accept`` is not called. This means they cannot
    renegotiate the extension or overwrite its state.

    :param requested: Individual entries of the client's
        ``Sec-WebSocket-Extensions`` request header.
    :param supported: The extensions the server is willing to use.
    :returns: The ``Sec-WebSocket-Extensions`` response header value, or
        ``None`` if there are no agreed extensions.
    """
    accepts: Dict[str, Union[bool, bytes]] = {}
    for offer in requested:
        name = offer.split(";", 1)[0].strip()
        for extension in supported:
            # Skip other extensions and ones already agreed from an earlier,
            # more preferred offer.
            if extension.name != name or name in accepts:
                continue
            accept = extension.accept(offer)
            if isinstance(accept, bool):
                if accept:
                    accepts[extension.name] = True
            elif accept is not None:
                accepts[extension.name] = accept.encode("ascii")

    if not accepts:
        return None

    extensions: List[bytes] = []
    for ext_name, params in accepts.items():
        name_bytes = ext_name.encode("ascii")
        if isinstance(params, bool) or params == b"":
            # Accepted without parameters: emit just the extension name.
            extensions.append(name_bytes)
        else:
            extensions.append(b"%s; %s" % (name_bytes, params))
    return b", ".join(extensions)
471
472
def client_extensions_handshake(
    accepted: Iterable[str], supported: Sequence[Extension]
) -> List[Extension]:
    """Finalize the extensions the server accepted.

    :param accepted: Individual entries of the server's
        ``Sec-WebSocket-Extensions`` response header.
    :param supported: The extensions the client offered.
    :returns: The offered extensions the server accepted, in the order the
        server listed them, each finalized with the server's parameters.
    :raises RemoteProtocolError: If the server accepted an extension the
        client did not offer, or listed the same extension more than once.
        A duplicate would otherwise be finalized twice and applied twice to
        the connection.
    """
    agreed: Dict[str, Extension] = {}
    for accept in accepted:
        name = accept.split(";", 1)[0].strip()
        if name in agreed:
            raise RemoteProtocolError(
                f"duplicate extension {name}", event_hint=RejectConnection()
            )
        for extension in supported:
            if extension.name == name:
                extension.finalize(accept)
                agreed[name] = extension
                break
        else:
            raise RemoteProtocolError(
                f"unrecognized extension {name}", event_hint=RejectConnection()
            )
    # dicts preserve insertion order, so this follows the server's listing.
    return list(agreed.values())
491