"""
wsproto/extensions
~~~~~~~~~~~~~~~~~~

WebSocket extensions.
"""

import zlib
from typing import Optional, Tuple, Union

from .frame_protocol import CloseReason, FrameDecoder, FrameProtocol, Opcode, RsvBits


class Extension:
    """Base class for WebSocket extensions.

    Every hook below has a no-op / pass-through default, so a subclass only
    needs to override the hooks it actually uses.
    """

    #: Extension token, e.g. ``"permessage-deflate"``; subclasses must set it.
    name: str

    def enabled(self) -> bool:
        """Return whether the extension is active; the base class never is."""
        return False

    def offer(self) -> Union[bool, str]:
        """Return the parameters to offer to the peer.

        NOTE(review): the base implementation returns ``None`` despite the
        annotation; presumably callers treat that as "nothing to offer".
        """
        pass

    def accept(self, offer: str) -> Optional[Union[bool, str]]:
        """Respond to a peer's ``offer``; returning ``None`` declines it.

        The base implementation always returns ``None``.
        """
        pass

    def finalize(self, offer: str) -> None:
        """Apply the parameters the peer accepted. The base is a no-op."""
        pass

    def frame_inbound_header(
        self,
        proto: Union[FrameDecoder, FrameProtocol],
        opcode: Opcode,
        rsv: RsvBits,
        payload_length: int,
    ) -> Union[CloseReason, RsvBits]:
        """Inspect the header of an incoming frame.

        Returns the RSV bits this extension makes use of, or a
        ``CloseReason`` to fail the connection. The base uses no RSV bits.
        """
        return RsvBits(False, False, False)

    def frame_inbound_payload_data(
        self, proto: Union[FrameDecoder, FrameProtocol], data: bytes
    ) -> Union[bytes, CloseReason]:
        """Transform a chunk of incoming payload; the base returns it as-is."""
        return data

    def frame_inbound_complete(
        self, proto: Union[FrameDecoder, FrameProtocol], fin: bool
    ) -> Union[bytes, CloseReason, None]:
        """Hook run once an incoming frame's payload is complete.

        May return extra payload bytes (presumably appended by the caller),
        a ``CloseReason`` to fail the connection, or ``None``. The base
        returns ``None``.
        """
        pass

    def frame_outbound(
        self,
        proto: Union[FrameDecoder, FrameProtocol],
        opcode: Opcode,
        rsv: RsvBits,
        data: bytes,
        fin: bool,
    ) -> Tuple[RsvBits, bytes]:
        """Transform an outgoing frame, returning its ``(rsv, data)``.

        The base returns both unchanged.
        """
        return (rsv, data)


class PerMessageDeflate(Extension):
    """The ``permessage-deflate`` extension (RFC 7692).

    Data messages are compressed with raw DEFLATE (negative ``wbits`` for
    zlib). The first frame of a compressed message carries RSV1; control
    frames are never compressed.
    """

    name = "permessage-deflate"

    DEFAULT_CLIENT_MAX_WINDOW_BITS = 15
    DEFAULT_SERVER_MAX_WINDOW_BITS = 15

    def __init__(
        self,
        client_no_context_takeover: bool = False,
        client_max_window_bits: Optional[int] = None,
        server_no_context_takeover: bool = False,
        server_max_window_bits: Optional[int] = None,
    ) -> None:
        """Configure the extension.

        :param client_no_context_takeover: reset the client's compression
            context after every message.
        :param client_max_window_bits: client LZ77 window size in bits
            (9..15); ``None`` keeps the default of 15.
        :param server_no_context_takeover: reset the server's compression
            context after every message.
        :param server_max_window_bits: server LZ77 window size in bits
            (9..15); ``None`` keeps the default of 15.
        :raises ValueError: if a window size is outside 9..15.
        """
        self.client_no_context_takeover = client_no_context_takeover
        self.server_no_context_takeover = server_no_context_takeover
        self._client_max_window_bits = self.DEFAULT_CLIENT_MAX_WINDOW_BITS
        self._server_max_window_bits = self.DEFAULT_SERVER_MAX_WINDOW_BITS
        if client_max_window_bits is not None:
            self.client_max_window_bits = client_max_window_bits
        if server_max_window_bits is not None:
            self.server_max_window_bits = server_max_window_bits

        self._compressor: Optional[zlib._Compress] = None  # noqa
        self._decompressor: Optional[zlib._Decompress] = None  # noqa
        # This refers to the current frame
        self._inbound_is_compressible: Optional[bool] = None
        # This refers to the ongoing message (which might span multiple
        # frames). Only the first frame in a fragmented message is flagged for
        # compression, so this carries that bit forward.
        self._inbound_compressed: Optional[bool] = None

        self._enabled = False

    @property
    def client_max_window_bits(self) -> int:
        """Client window size in bits (validated to 9..15 on assignment)."""
        return self._client_max_window_bits

    @client_max_window_bits.setter
    def client_max_window_bits(self, value: int) -> None:
        if value < 9 or value > 15:
            raise ValueError("Window size must be between 9 and 15 inclusive")
        self._client_max_window_bits = value

    @property
    def server_max_window_bits(self) -> int:
        """Server window size in bits (validated to 9..15 on assignment)."""
        return self._server_max_window_bits

    @server_max_window_bits.setter
    def server_max_window_bits(self, value: int) -> None:
        if value < 9 or value > 15:
            raise ValueError("Window size must be between 9 and 15 inclusive")
        self._server_max_window_bits = value

    @staticmethod
    def _parse_window_bits(value: str) -> int:
        """Parse a ``*_max_window_bits`` parameter value.

        RFC 7692 permits the quoted-string form (``"10"``) as well as a bare
        token, so one pair of surrounding double quotes is removed.

        :raises ValueError: if the value is not an integer.
        """
        value = value.strip()
        if len(value) >= 2 and value[0] == value[-1] == '"':
            value = value[1:-1]
        return int(value)

    def _compressible_opcode(self, opcode: Opcode) -> bool:
        """Return True for data opcodes (TEXT, BINARY, CONTINUATION)."""
        return opcode in (Opcode.TEXT, Opcode.BINARY, Opcode.CONTINUATION)

    def enabled(self) -> bool:
        """Return whether ``finalize`` or ``accept`` enabled the extension."""
        return self._enabled

    def offer(self) -> Union[bool, str]:
        """Client side: build the offer parameter string.

        Both window sizes are always included; the no-context-takeover flags
        are included only when set.
        """
        parameters = [
            "client_max_window_bits=%d" % self.client_max_window_bits,
            "server_max_window_bits=%d" % self.server_max_window_bits,
        ]

        if self.client_no_context_takeover:
            parameters.append("client_no_context_takeover")
        if self.server_no_context_takeover:
            parameters.append("server_no_context_takeover")

        return "; ".join(parameters)

    def finalize(self, offer: str) -> None:
        """Client side: apply the parameters from the server's response.

        ``offer`` is split on ``;`` and its first item (presumably the
        extension token) is skipped. The extension is then enabled.

        :raises ValueError: for a non-integer or out-of-range window size.
        :raises IndexError: if a window-size parameter has no ``=value``.
        """
        bits = [b.strip() for b in offer.split(";")]
        for bit in bits[1:]:
            if bit.startswith("client_no_context_takeover"):
                self.client_no_context_takeover = True
            elif bit.startswith("server_no_context_takeover"):
                self.server_no_context_takeover = True
            elif bit.startswith("client_max_window_bits"):
                self.client_max_window_bits = self._parse_window_bits(
                    bit.split("=", 1)[1]
                )
            elif bit.startswith("server_max_window_bits"):
                self.server_max_window_bits = self._parse_window_bits(
                    bit.split("=", 1)[1]
                )

        self._enabled = True

    def _parse_params(self, params: str) -> Tuple[Optional[int], Optional[int]]:
        """Parse a client offer; return ``(client_bits, server_bits)``.

        No-context-takeover flags found in the offer are set on ``self``
        directly. A window-size parameter without a value resolves to the
        currently configured size; an absent one resolves to ``None``.

        :raises ValueError: if a window-size value is not an integer.
        """
        client_max_window_bits = None
        server_max_window_bits = None

        bits = [b.strip() for b in params.split(";")]
        for bit in bits[1:]:
            if bit.startswith("client_no_context_takeover"):
                self.client_no_context_takeover = True
            elif bit.startswith("server_no_context_takeover"):
                self.server_no_context_takeover = True
            elif bit.startswith("client_max_window_bits"):
                if "=" in bit:
                    client_max_window_bits = self._parse_window_bits(
                        bit.split("=", 1)[1]
                    )
                else:
                    client_max_window_bits = self.client_max_window_bits
            elif bit.startswith("server_max_window_bits"):
                if "=" in bit:
                    server_max_window_bits = self._parse_window_bits(
                        bit.split("=", 1)[1]
                    )
                else:
                    server_max_window_bits = self.server_max_window_bits

        return client_max_window_bits, server_max_window_bits

    def accept(self, offer: str) -> Union[bool, None, str]:
        """Server side: respond to a client's offer.

        Returns the response parameter string (possibly empty) and enables
        the extension. Returns ``None`` to decline when the offer carries a
        malformed or out-of-range (outside 9..15) window size.
        """
        try:
            # The offer is untrusted peer input: a non-numeric window size
            # must decline the offer rather than escape to the caller.
            client_max_window_bits, server_max_window_bits = self._parse_params(offer)
        except ValueError:
            return None

        parameters = []

        if self.client_no_context_takeover:
            parameters.append("client_no_context_takeover")
        if self.server_no_context_takeover:
            parameters.append("server_no_context_takeover")
        try:
            if client_max_window_bits is not None:
                parameters.append("client_max_window_bits=%d" % client_max_window_bits)
                self.client_max_window_bits = client_max_window_bits
            else:
                # RFC 7692 section 7.1.2.2: a client that did not offer this
                # parameter may compress with the full 15-bit window, and the
                # response must not mention it. Inflate with the default
                # window, not a smaller locally configured limit.
                self.client_max_window_bits = self.DEFAULT_CLIENT_MAX_WINDOW_BITS
            if server_max_window_bits is not None:
                parameters.append("server_max_window_bits=%d" % server_max_window_bits)
                self.server_max_window_bits = server_max_window_bits
        except ValueError:
            return None

        self._enabled = True
        return "; ".join(parameters)

    def frame_inbound_header(
        self,
        proto: Union[FrameDecoder, FrameProtocol],
        opcode: Opcode,
        rsv: RsvBits,
        payload_length: int,
    ) -> Union[CloseReason, RsvBits]:
        """Update per-frame and per-message decompression state.

        RSV1 marks a compressed message and is only valid on the first frame
        of a data message. It is therefore a protocol error on control and
        continuation frames.
        """
        if rsv.rsv1 and opcode.iscontrol():
            return CloseReason.PROTOCOL_ERROR
        if rsv.rsv1 and opcode is Opcode.CONTINUATION:
            return CloseReason.PROTOCOL_ERROR

        self._inbound_is_compressible = self._compressible_opcode(opcode)
        if not self._inbound_is_compressible:
            # Control frames are never compressed and may be interleaved
            # between the fragments of a data message. They must not start
            # or end the per-message state, or the rest of the message would
            # be passed through without inflating.
            return RsvBits(True, False, False)

        if self._inbound_compressed is None:
            # First frame of a new message: its RSV1 bit decides for every
            # fragment of that message.
            self._inbound_compressed = rsv.rsv1
            if self._inbound_compressed:
                # We inflate what the peer deflated, so use the peer's window.
                if proto.client:
                    bits = self.server_max_window_bits
                else:
                    bits = self.client_max_window_bits
                if self._decompressor is None:
                    self._decompressor = zlib.decompressobj(-int(bits))

        return RsvBits(True, False, False)

    def frame_inbound_payload_data(
        self, proto: Union[FrameDecoder, FrameProtocol], data: bytes
    ) -> Union[bytes, CloseReason]:
        """Inflate a payload chunk of a compressed message.

        Other payloads are passed through unchanged.
        """
        if not self._inbound_compressed or not self._inbound_is_compressible:
            return data
        assert self._decompressor is not None

        try:
            return self._decompressor.decompress(bytes(data))
        except zlib.error:
            return CloseReason.INVALID_FRAME_PAYLOAD_DATA

    def frame_inbound_complete(
        self, proto: Union[FrameDecoder, FrameProtocol], fin: bool
    ) -> Union[bytes, CloseReason, None]:
        """Finish a compressed message on its final frame.

        Returns the remaining inflated bytes, a ``CloseReason`` on corrupt
        data, or ``None`` when there is nothing to add.
        """
        if not self._inbound_is_compressible:
            # Control frame: leave any in-progress message untouched.
            return None
        if not fin:
            return None
        if not self._inbound_compressed:
            self._inbound_compressed = None
            return None
        assert self._decompressor is not None

        try:
            # The sender strips the trailing 0x00 0x00 0xff 0xff of its sync
            # flush (RFC 7692 section 7.2.2), so append it before flushing.
            data = self._decompressor.decompress(b"\x00\x00\xff\xff")
            data += self._decompressor.flush()
        except zlib.error:
            return CloseReason.INVALID_FRAME_PAYLOAD_DATA

        # The peer's no_context_takeover flag governs our decompressor.
        if proto.client:
            no_context_takeover = self.server_no_context_takeover
        else:
            no_context_takeover = self.client_no_context_takeover

        if no_context_takeover:
            self._decompressor = None

        self._inbound_compressed = None

        return data

    def frame_outbound(
        self,
        proto: Union[FrameDecoder, FrameProtocol],
        opcode: Opcode,
        rsv: RsvBits,
        data: bytes,
        fin: bool,
    ) -> Tuple[RsvBits, bytes]:
        """Deflate an outgoing data frame.

        RSV1 is set on the first frame of a message only. On the final
        frame the stream is sync-flushed and its 4-byte
        0x00 0x00 0xff 0xff tail is removed (RFC 7692 section 7.2.1).
        Control frames are returned unchanged.
        """
        if not self._compressible_opcode(opcode):
            return (rsv, data)

        if opcode is not Opcode.CONTINUATION:
            rsv = RsvBits(True, *rsv[1:])

        if self._compressor is None:
            assert opcode is not Opcode.CONTINUATION
            # We deflate with our own role's window size.
            if proto.client:
                bits = self.client_max_window_bits
            else:
                bits = self.server_max_window_bits
            self._compressor = zlib.compressobj(
                zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -int(bits)
            )

        data = self._compressor.compress(bytes(data))

        if fin:
            data += self._compressor.flush(zlib.Z_SYNC_FLUSH)
            data = data[:-4]

            if proto.client:
                no_context_takeover = self.client_no_context_takeover
            else:
                no_context_takeover = self.server_no_context_takeover

            if no_context_takeover:
                self._compressor = None

        return (rsv, data)

    def __repr__(self) -> str:
        descr = ["client_max_window_bits=%d" % self.client_max_window_bits]
        if self.client_no_context_takeover:
            descr.append("client_no_context_takeover")
        descr.append("server_max_window_bits=%d" % self.server_max_window_bits)
        if self.server_no_context_takeover:
            descr.append("server_no_context_takeover")

        return "<{} {}>".format(self.__class__.__name__, "; ".join(descr))


#: SUPPORTED_EXTENSIONS maps all supported extension names to their class.
#: This can be used to iterate all supported extensions of wsproto, instantiate
#: new extensions based on their name, or check if a given extension is
#: supported or not.
SUPPORTED_EXTENSIONS = {PerMessageDeflate.name: PerMessageDeflate}