1"""
2wsproto/extensions
3~~~~~~~~~~~~~~~~~~
4
5WebSocket extensions.
6"""
7
8import zlib
9from typing import Optional, Tuple, Union
10
11from .frame_protocol import CloseReason, FrameDecoder, FrameProtocol, Opcode, RsvBits
12
13
class Extension:
    """Base class for WebSocket extensions.

    Every hook here is a no-op: the extension reports itself as disabled,
    offers and accepts nothing, claims none of the RSV bits and hands payload
    data back unchanged. Concrete extensions override the hooks they need.
    """

    # Extension token used during negotiation; set by subclasses.
    name: str

    def enabled(self) -> bool:
        """Report whether the extension is active (never, for the base class)."""
        return False

    def offer(self) -> Union[bool, str]:
        """Produce this side's negotiation offer; the base class returns ``None``."""

    def accept(self, offer: str) -> Optional[Union[bool, str]]:
        """Answer a peer's negotiation *offer*; the base class returns ``None``."""

    def finalize(self, offer: str) -> None:
        """Apply an accepted negotiation; the base class does nothing."""

    def frame_inbound_header(
        self,
        proto: Union[FrameDecoder, FrameProtocol],
        opcode: Opcode,
        rsv: RsvBits,
        payload_length: int,
    ) -> Union[CloseReason, RsvBits]:
        """Inspect an incoming frame header.

        Returns the RSV bits this extension takes responsibility for (none,
        here), or a ``CloseReason`` to reject the frame.
        """
        return RsvBits(False, False, False)

    def frame_inbound_payload_data(
        self, proto: Union[FrameDecoder, FrameProtocol], data: bytes
    ) -> Union[bytes, CloseReason]:
        """Transform a chunk of incoming payload; the base class passes it through."""
        return data

    def frame_inbound_complete(
        self, proto: Union[FrameDecoder, FrameProtocol], fin: bool
    ) -> Union[bytes, CloseReason, None]:
        """Hook called once a frame's payload is consumed; the base class returns ``None``."""

    def frame_outbound(
        self,
        proto: Union[FrameDecoder, FrameProtocol],
        opcode: Opcode,
        rsv: RsvBits,
        data: bytes,
        fin: bool,
    ) -> Tuple[RsvBits, bytes]:
        """Transform an outgoing frame; the base class returns *rsv* and *data* as-is."""
        return rsv, data
57
58
class PerMessageDeflate(Extension):
    """The ``permessage-deflate`` WebSocket extension (RFC 7692).

    Negotiates window sizes and context-takeover flags, then compresses
    outbound data messages and decompresses inbound ones using raw DEFLATE
    streams from :mod:`zlib`. Control frames pass through untouched.
    """

    name = "permessage-deflate"

    DEFAULT_CLIENT_MAX_WINDOW_BITS = 15
    DEFAULT_SERVER_MAX_WINDOW_BITS = 15

    def __init__(
        self,
        client_no_context_takeover: bool = False,
        client_max_window_bits: Optional[int] = None,
        server_no_context_takeover: bool = False,
        server_max_window_bits: Optional[int] = None,
    ) -> None:
        """Configure the parameters used when offering or accepting.

        Window sizes default to ``DEFAULT_*_MAX_WINDOW_BITS``. Explicit values
        go through the validating property setters, which raise ``ValueError``
        when a value is outside 9..15.
        """
        self.client_no_context_takeover = client_no_context_takeover
        self.server_no_context_takeover = server_no_context_takeover
        self._client_max_window_bits = self.DEFAULT_CLIENT_MAX_WINDOW_BITS
        self._server_max_window_bits = self.DEFAULT_SERVER_MAX_WINDOW_BITS
        if client_max_window_bits is not None:
            self.client_max_window_bits = client_max_window_bits
        if server_max_window_bits is not None:
            self.server_max_window_bits = server_max_window_bits

        self._compressor: Optional[zlib._Compress] = None  # noqa
        self._decompressor: Optional[zlib._Decompress] = None  # noqa
        # This refers to the current frame (data or control).
        self._inbound_is_compressible: Optional[bool] = None
        # This refers to the ongoing data message, which might span multiple
        # frames. Only the first frame of a fragmented message is flagged for
        # compression, so this carries that bit forward. Control frames
        # interleaved between fragments never modify it.
        self._inbound_compressed: Optional[bool] = None

        self._enabled = False

    @staticmethod
    def _check_window_bits(value: int) -> int:
        """Return *value* if it is an accepted window size, else raise ValueError.

        RFC 7692 allows 8..15; this implementation only accepts 9..15.
        """
        if not 9 <= value <= 15:
            raise ValueError("Window size must be between 9 and 15 inclusive")
        return value

    @property
    def client_max_window_bits(self) -> int:
        return self._client_max_window_bits

    @client_max_window_bits.setter
    def client_max_window_bits(self, value: int) -> None:
        self._client_max_window_bits = self._check_window_bits(value)

    @property
    def server_max_window_bits(self) -> int:
        return self._server_max_window_bits

    @server_max_window_bits.setter
    def server_max_window_bits(self, value: int) -> None:
        self._server_max_window_bits = self._check_window_bits(value)

    def _compressible_opcode(self, opcode: Opcode) -> bool:
        return opcode in (Opcode.TEXT, Opcode.BINARY, Opcode.CONTINUATION)

    @staticmethod
    def _split_param(param: str) -> Tuple[str, Optional[str]]:
        """Split one ``name[=value]`` extension parameter.

        Returns the stripped name and the stripped value, or ``None`` for the
        value when there is no ``=``. Extension parameter values may be sent
        as quoted strings (RFC 6455 section 9.1), so one pair of surrounding
        double quotes is removed.
        """
        name, sep, value = param.partition("=")
        if not sep:
            return name.strip(), None
        value = value.strip()
        if len(value) >= 2 and value[0] == value[-1] == '"':
            value = value[1:-1].strip()
        return name.strip(), value

    @staticmethod
    def _window_bits_value(name: str, value: Optional[str]) -> int:
        """Convert a window-size value that must be present into an int."""
        if value is None:
            raise ValueError("%s requires a value" % name)
        return int(value)

    def enabled(self) -> bool:
        return self._enabled

    def offer(self) -> Union[bool, str]:
        """Build the offer parameters (without the extension name)."""
        parameters = [
            "client_max_window_bits=%d" % self.client_max_window_bits,
            "server_max_window_bits=%d" % self.server_max_window_bits,
        ]

        if self.client_no_context_takeover:
            parameters.append("client_no_context_takeover")
        if self.server_no_context_takeover:
            parameters.append("server_no_context_takeover")

        return "; ".join(parameters)

    def finalize(self, offer: str) -> None:
        """Apply the negotiated parameters in *offer* and enable the extension.

        *offer* is the full extension string, starting with the extension
        name, which is skipped. Window-size parameters must carry a value. A
        missing, non-numeric or out-of-range value raises ``ValueError``.
        """
        for param in offer.split(";")[1:]:
            name, value = self._split_param(param)
            if name == "client_no_context_takeover":
                self.client_no_context_takeover = True
            elif name == "server_no_context_takeover":
                self.server_no_context_takeover = True
            elif name == "client_max_window_bits":
                self.client_max_window_bits = self._window_bits_value(name, value)
            elif name == "server_max_window_bits":
                self.server_max_window_bits = self._window_bits_value(name, value)

        self._enabled = True

    def _parse_params(self, params: str) -> Tuple[Optional[int], Optional[int]]:
        """Parse the parameters of a peer's offer.

        Side effect: any ``*_no_context_takeover`` parameter sets the
        matching flag on ``self``. Returns ``(client_bits, server_bits)``.
        For each entry:

        * ``None`` means the parameter was absent.
        * Our current setting is used when the parameter had no value.
        * Otherwise the parsed integer is returned.

        Raises ``ValueError`` for a non-numeric value.
        """
        client_max_window_bits = None
        server_max_window_bits = None

        for param in params.split(";")[1:]:
            name, value = self._split_param(param)
            if name == "client_no_context_takeover":
                self.client_no_context_takeover = True
            elif name == "server_no_context_takeover":
                self.server_no_context_takeover = True
            elif name == "client_max_window_bits":
                if value is None:
                    client_max_window_bits = self.client_max_window_bits
                else:
                    client_max_window_bits = int(value)
            elif name == "server_max_window_bits":
                if value is None:
                    server_max_window_bits = self.server_max_window_bits
                else:
                    server_max_window_bits = int(value)

        return client_max_window_bits, server_max_window_bits

    def accept(self, offer: str) -> Union[bool, None, str]:
        """Answer a peer's *offer*, enabling the extension on success.

        Returns the response parameter string. Returns ``None`` to decline
        when a requested window size is malformed or outside the supported
        range.
        """
        try:
            client_max_window_bits, server_max_window_bits = self._parse_params(offer)
            # The setters validate the range; declining is preferable to
            # propagating the ValueError to the caller.
            if client_max_window_bits is not None:
                self.client_max_window_bits = client_max_window_bits
            if server_max_window_bits is not None:
                self.server_max_window_bits = server_max_window_bits
        except ValueError:
            return None

        parameters = []
        if self.client_no_context_takeover:
            parameters.append("client_no_context_takeover")
        if self.server_no_context_takeover:
            parameters.append("server_no_context_takeover")
        if client_max_window_bits is not None:
            parameters.append("client_max_window_bits=%d" % client_max_window_bits)
        if server_max_window_bits is not None:
            parameters.append("server_max_window_bits=%d" % server_max_window_bits)

        self._enabled = True
        return "; ".join(parameters)

    def frame_inbound_header(
        self,
        proto: Union[FrameDecoder, FrameProtocol],
        opcode: Opcode,
        rsv: RsvBits,
        payload_length: int,
    ) -> Union[CloseReason, RsvBits]:
        """Validate RSV1 and track whether the current data message is compressed."""
        # RSV1 ("compressed") is only legal on the first frame of a data message.
        if rsv.rsv1 and opcode.iscontrol():
            return CloseReason.PROTOCOL_ERROR
        if rsv.rsv1 and opcode is Opcode.CONTINUATION:
            return CloseReason.PROTOCOL_ERROR

        self._inbound_is_compressible = self._compressible_opcode(opcode)

        # Control frames may arrive between the fragments of a data message
        # (RFC 6455 section 5.4), so they must not start, end or otherwise
        # alter the message-level compression state.
        if self._inbound_is_compressible and self._inbound_compressed is None:
            self._inbound_compressed = rsv.rsv1
            if self._inbound_compressed:
                # We decompress what the peer compressed, using its window.
                if proto.client:
                    bits = self.server_max_window_bits
                else:
                    bits = self.client_max_window_bits
                if self._decompressor is None:
                    self._decompressor = zlib.decompressobj(-int(bits))

        return RsvBits(True, False, False)

    def frame_inbound_payload_data(
        self, proto: Union[FrameDecoder, FrameProtocol], data: bytes
    ) -> Union[bytes, CloseReason]:
        """Decompress payload of a compressed data message; pass anything else through."""
        if not self._inbound_compressed or not self._inbound_is_compressible:
            return data
        assert self._decompressor is not None

        try:
            return self._decompressor.decompress(bytes(data))
        except zlib.error:
            return CloseReason.INVALID_FRAME_PAYLOAD_DATA

    def frame_inbound_complete(
        self, proto: Union[FrameDecoder, FrameProtocol], fin: bool
    ) -> Union[bytes, CloseReason, None]:
        """Finish a frame, emitting the trailing decompressed bytes of a message."""
        if not fin:
            return None
        if not self._inbound_is_compressible:
            # A control frame ended; any data message it interrupted is still
            # in progress, so its compression state is left untouched.
            return None
        if not self._inbound_compressed:
            self._inbound_compressed = None
            return None
        assert self._decompressor is not None

        try:
            # Re-append the empty-block tail the sender stripped (RFC 7692 7.2.2).
            data = self._decompressor.decompress(b"\x00\x00\xff\xff")
            data += self._decompressor.flush()
        except zlib.error:
            return CloseReason.INVALID_FRAME_PAYLOAD_DATA

        if proto.client:
            no_context_takeover = self.server_no_context_takeover
        else:
            no_context_takeover = self.client_no_context_takeover

        if no_context_takeover:
            self._decompressor = None

        self._inbound_compressed = None

        return data

    def frame_outbound(
        self,
        proto: Union[FrameDecoder, FrameProtocol],
        opcode: Opcode,
        rsv: RsvBits,
        data: bytes,
        fin: bool,
    ) -> Tuple[RsvBits, bytes]:
        """Compress an outgoing data frame; control frames pass through unchanged."""
        if not self._compressible_opcode(opcode):
            return (rsv, data)

        # RSV1 marks only the first frame of a compressed message.
        if opcode is not Opcode.CONTINUATION:
            rsv = RsvBits(True, *rsv[1:])

        if self._compressor is None:
            assert opcode is not Opcode.CONTINUATION
            if proto.client:
                bits = self.client_max_window_bits
            else:
                bits = self.server_max_window_bits
            self._compressor = zlib.compressobj(
                zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -int(bits)
            )

        data = self._compressor.compress(bytes(data))

        if fin:
            data += self._compressor.flush(zlib.Z_SYNC_FLUSH)
            # Drop the 0x00 0x00 0xff 0xff tail of the sync flush (RFC 7692 7.2.1).
            data = data[:-4]

            if proto.client:
                no_context_takeover = self.client_no_context_takeover
            else:
                no_context_takeover = self.server_no_context_takeover

            if no_context_takeover:
                self._compressor = None

        return (rsv, data)

    def __repr__(self) -> str:
        descr = ["client_max_window_bits=%d" % self.client_max_window_bits]
        if self.client_no_context_takeover:
            descr.append("client_no_context_takeover")
        descr.append("server_max_window_bits=%d" % self.server_max_window_bits)
        if self.server_no_context_takeover:
            descr.append("server_no_context_takeover")

        return "<{} {}>".format(self.__class__.__name__, "; ".join(descr))
309
310
#: Registry of every extension wsproto implements, keyed by each extension's
#: negotiation name (its ``name`` attribute) and mapping to the class itself.
#: Use it to enumerate the supported extensions, to instantiate one from its
#: name, or to check whether a requested extension is supported.
SUPPORTED_EXTENSIONS = {ext.name: ext for ext in (PerMessageDeflate,)}
316