1""" 2websocket - WebSocket client library for Python 3 4Copyright (C) 2010 Hiroki Ohtani(liris) 5 6 This library is free software; you can redistribute it and/or 7 modify it under the terms of the GNU Lesser General Public 8 License as published by the Free Software Foundation; either 9 version 2.1 of the License, or (at your option) any later version. 10 11 This library is distributed in the hope that it will be useful, 12 but WITHOUT ANY WARRANTY; without even the implied warranty of 13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 14 Lesser General Public License for more details. 15 16 You should have received a copy of the GNU Lesser General Public 17 License along with this library; if not, write to the Free Software 18 Foundation, Inc., 51 Franklin Street, Fifth Floor, 19 Boston, MA 02110-1335 USA 20 21""" 22import array 23import os 24import struct 25 26import six 27 28from ._exceptions import * 29from ._utils import validate_utf8 30from threading import Lock 31 32try: 33 if six.PY3: 34 import numpy 35 else: 36 numpy = None 37except ImportError: 38 numpy = None 39 40try: 41 # If wsaccel is available we use compiled routines to mask data. 42 if not numpy: 43 from wsaccel.xormask import XorMaskerSimple 44 45 def _mask(_m, _d): 46 return XorMaskerSimple(_m).process(_d) 47except ImportError: 48 # wsaccel is not available, we rely on python implementations. 49 def _mask(_m, _d): 50 for i in range(len(_d)): 51 _d[i] ^= _m[i % 4] 52 53 if six.PY3: 54 return _d.tobytes() 55 else: 56 return _d.tostring() 57 58 59__all__ = [ 60 'ABNF', 'continuous_frame', 'frame_buffer', 61 'STATUS_NORMAL', 62 'STATUS_GOING_AWAY', 63 'STATUS_PROTOCOL_ERROR', 64 'STATUS_UNSUPPORTED_DATA_TYPE', 65 'STATUS_STATUS_NOT_AVAILABLE', 66 'STATUS_ABNORMAL_CLOSED', 67 'STATUS_INVALID_PAYLOAD', 68 'STATUS_POLICY_VIOLATION', 69 'STATUS_MESSAGE_TOO_BIG', 70 'STATUS_INVALID_EXTENSION', 71 'STATUS_UNEXPECTED_CONDITION', 72 'STATUS_BAD_GATEWAY', 73 'STATUS_TLS_HANDSHAKE_ERROR', 74] 75 76# closing frame status codes. 77STATUS_NORMAL = 1000 78STATUS_GOING_AWAY = 1001 79STATUS_PROTOCOL_ERROR = 1002 80STATUS_UNSUPPORTED_DATA_TYPE = 1003 81STATUS_STATUS_NOT_AVAILABLE = 1005 82STATUS_ABNORMAL_CLOSED = 1006 83STATUS_INVALID_PAYLOAD = 1007 84STATUS_POLICY_VIOLATION = 1008 85STATUS_MESSAGE_TOO_BIG = 1009 86STATUS_INVALID_EXTENSION = 1010 87STATUS_UNEXPECTED_CONDITION = 1011 88STATUS_BAD_GATEWAY = 1014 89STATUS_TLS_HANDSHAKE_ERROR = 1015 90 91VALID_CLOSE_STATUS = ( 92 STATUS_NORMAL, 93 STATUS_GOING_AWAY, 94 STATUS_PROTOCOL_ERROR, 95 STATUS_UNSUPPORTED_DATA_TYPE, 96 STATUS_INVALID_PAYLOAD, 97 STATUS_POLICY_VIOLATION, 98 STATUS_MESSAGE_TOO_BIG, 99 STATUS_INVALID_EXTENSION, 100 STATUS_UNEXPECTED_CONDITION, 101 STATUS_BAD_GATEWAY, 102) 103 104 105class ABNF(object): 106 """ 107 ABNF frame class. 108 see http://tools.ietf.org/html/rfc5234 109 and http://tools.ietf.org/html/rfc6455#section-5.2 110 """ 111 112 # operation code values. 113 OPCODE_CONT = 0x0 114 OPCODE_TEXT = 0x1 115 OPCODE_BINARY = 0x2 116 OPCODE_CLOSE = 0x8 117 OPCODE_PING = 0x9 118 OPCODE_PONG = 0xa 119 120 # available operation code value tuple 121 OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE, 122 OPCODE_PING, OPCODE_PONG) 123 124 # opcode human readable string 125 OPCODE_MAP = { 126 OPCODE_CONT: "cont", 127 OPCODE_TEXT: "text", 128 OPCODE_BINARY: "binary", 129 OPCODE_CLOSE: "close", 130 OPCODE_PING: "ping", 131 OPCODE_PONG: "pong" 132 } 133 134 # data length threshold. 135 LENGTH_7 = 0x7e 136 LENGTH_16 = 1 << 16 137 LENGTH_63 = 1 << 63 138 139 def __init__(self, fin=0, rsv1=0, rsv2=0, rsv3=0, 140 opcode=OPCODE_TEXT, mask=1, data=""): 141 """ 142 Constructor for ABNF. 143 please check RFC for arguments. 144 """ 145 self.fin = fin 146 self.rsv1 = rsv1 147 self.rsv2 = rsv2 148 self.rsv3 = rsv3 149 self.opcode = opcode 150 self.mask = mask 151 if data is None: 152 data = "" 153 self.data = data 154 self.get_mask_key = os.urandom 155 156 def validate(self, skip_utf8_validation=False): 157 """ 158 validate the ABNF frame. 159 skip_utf8_validation: skip utf8 validation. 160 """ 161 if self.rsv1 or self.rsv2 or self.rsv3: 162 raise WebSocketProtocolException("rsv is not implemented, yet") 163 164 if self.opcode not in ABNF.OPCODES: 165 raise WebSocketProtocolException("Invalid opcode %r", self.opcode) 166 167 if self.opcode == ABNF.OPCODE_PING and not self.fin: 168 raise WebSocketProtocolException("Invalid ping frame.") 169 170 if self.opcode == ABNF.OPCODE_CLOSE: 171 l = len(self.data) 172 if not l: 173 return 174 if l == 1 or l >= 126: 175 raise WebSocketProtocolException("Invalid close frame.") 176 if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]): 177 raise WebSocketProtocolException("Invalid close frame.") 178 179 code = 256 * \ 180 six.byte2int(self.data[0:1]) + six.byte2int(self.data[1:2]) 181 if not self._is_valid_close_status(code): 182 raise WebSocketProtocolException("Invalid close opcode.") 183 184 @staticmethod 185 def _is_valid_close_status(code): 186 return code in VALID_CLOSE_STATUS or (3000 <= code < 5000) 187 188 def __str__(self): 189 return "fin=" + str(self.fin) \ 190 + " opcode=" + str(self.opcode) \ 191 + " data=" + str(self.data) 192 193 @staticmethod 194 def create_frame(data, opcode, fin=1): 195 """ 196 create frame to send text, binary and other data. 197 198 data: data to send. This is string value(byte array). 199 if opcode is OPCODE_TEXT and this value is unicode, 200 data value is converted into unicode string, automatically. 201 202 opcode: operation code. please see OPCODE_XXX. 203 204 fin: fin flag. if set to 0, create continue fragmentation. 205 """ 206 if opcode == ABNF.OPCODE_TEXT and isinstance(data, six.text_type): 207 data = data.encode("utf-8") 208 # mask must be set if send data from client 209 return ABNF(fin, 0, 0, 0, opcode, 1, data) 210 211 def format(self): 212 """ 213 format this object to string(byte array) to send data to server. 214 """ 215 if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]): 216 raise ValueError("not 0 or 1") 217 if self.opcode not in ABNF.OPCODES: 218 raise ValueError("Invalid OPCODE") 219 length = len(self.data) 220 if length >= ABNF.LENGTH_63: 221 raise ValueError("data is too long") 222 223 frame_header = chr(self.fin << 7 224 | self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4 225 | self.opcode) 226 if length < ABNF.LENGTH_7: 227 frame_header += chr(self.mask << 7 | length) 228 frame_header = six.b(frame_header) 229 elif length < ABNF.LENGTH_16: 230 frame_header += chr(self.mask << 7 | 0x7e) 231 frame_header = six.b(frame_header) 232 frame_header += struct.pack("!H", length) 233 else: 234 frame_header += chr(self.mask << 7 | 0x7f) 235 frame_header = six.b(frame_header) 236 frame_header += struct.pack("!Q", length) 237 238 if not self.mask: 239 return frame_header + self.data 240 else: 241 mask_key = self.get_mask_key(4) 242 return frame_header + self._get_masked(mask_key) 243 244 def _get_masked(self, mask_key): 245 s = ABNF.mask(mask_key, self.data) 246 247 if isinstance(mask_key, six.text_type): 248 mask_key = mask_key.encode('utf-8') 249 250 return mask_key + s 251 252 @staticmethod 253 def mask(mask_key, data): 254 """ 255 mask or unmask data. Just do xor for each byte 256 257 mask_key: 4 byte string(byte). 258 259 data: data to mask/unmask. 260 """ 261 if data is None: 262 data = "" 263 264 if isinstance(mask_key, six.text_type): 265 mask_key = six.b(mask_key) 266 267 if isinstance(data, six.text_type): 268 data = six.b(data) 269 270 if numpy: 271 origlen = len(data) 272 _mask_key = mask_key[3] << 24 | mask_key[2] << 16 | mask_key[1] << 8 | mask_key[0] 273 274 # We need data to be a multiple of four... 275 data += bytes(" " * (4 - (len(data) % 4)), "us-ascii") 276 a = numpy.frombuffer(data, dtype="uint32") 277 masked = numpy.bitwise_xor(a, [_mask_key]).astype("uint32") 278 if len(data) > origlen: 279 return masked.tobytes()[:origlen] 280 return masked.tobytes() 281 else: 282 _m = array.array("B", mask_key) 283 _d = array.array("B", data) 284 return _mask(_m, _d) 285 286 287class frame_buffer(object): 288 _HEADER_MASK_INDEX = 5 289 _HEADER_LENGTH_INDEX = 6 290 291 def __init__(self, recv_fn, skip_utf8_validation): 292 self.recv = recv_fn 293 self.skip_utf8_validation = skip_utf8_validation 294 # Buffers over the packets from the layer beneath until desired amount 295 # bytes of bytes are received. 296 self.recv_buffer = [] 297 self.clear() 298 self.lock = Lock() 299 300 def clear(self): 301 self.header = None 302 self.length = None 303 self.mask = None 304 305 def has_received_header(self): 306 return self.header is None 307 308 def recv_header(self): 309 header = self.recv_strict(2) 310 b1 = header[0] 311 312 if six.PY2: 313 b1 = ord(b1) 314 315 fin = b1 >> 7 & 1 316 rsv1 = b1 >> 6 & 1 317 rsv2 = b1 >> 5 & 1 318 rsv3 = b1 >> 4 & 1 319 opcode = b1 & 0xf 320 b2 = header[1] 321 322 if six.PY2: 323 b2 = ord(b2) 324 325 has_mask = b2 >> 7 & 1 326 length_bits = b2 & 0x7f 327 328 self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits) 329 330 def has_mask(self): 331 if not self.header: 332 return False 333 return self.header[frame_buffer._HEADER_MASK_INDEX] 334 335 def has_received_length(self): 336 return self.length is None 337 338 def recv_length(self): 339 bits = self.header[frame_buffer._HEADER_LENGTH_INDEX] 340 length_bits = bits & 0x7f 341 if length_bits == 0x7e: 342 v = self.recv_strict(2) 343 self.length = struct.unpack("!H", v)[0] 344 elif length_bits == 0x7f: 345 v = self.recv_strict(8) 346 self.length = struct.unpack("!Q", v)[0] 347 else: 348 self.length = length_bits 349 350 def has_received_mask(self): 351 return self.mask is None 352 353 def recv_mask(self): 354 self.mask = self.recv_strict(4) if self.has_mask() else "" 355 356 def recv_frame(self): 357 358 with self.lock: 359 # Header 360 if self.has_received_header(): 361 self.recv_header() 362 (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header 363 364 # Frame length 365 if self.has_received_length(): 366 self.recv_length() 367 length = self.length 368 369 # Mask 370 if self.has_received_mask(): 371 self.recv_mask() 372 mask = self.mask 373 374 # Payload 375 payload = self.recv_strict(length) 376 if has_mask: 377 payload = ABNF.mask(mask, payload) 378 379 # Reset for next frame 380 self.clear() 381 382 frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload) 383 frame.validate(self.skip_utf8_validation) 384 385 return frame 386 387 def recv_strict(self, bufsize): 388 shortage = bufsize - sum(len(x) for x in self.recv_buffer) 389 while shortage > 0: 390 # Limit buffer size that we pass to socket.recv() to avoid 391 # fragmenting the heap -- the number of bytes recv() actually 392 # reads is limited by socket buffer and is relatively small, 393 # yet passing large numbers repeatedly causes lots of large 394 # buffers allocated and then shrunk, which results in 395 # fragmentation. 396 bytes_ = self.recv(min(16384, shortage)) 397 self.recv_buffer.append(bytes_) 398 shortage -= len(bytes_) 399 400 unified = six.b("").join(self.recv_buffer) 401 402 if shortage == 0: 403 self.recv_buffer = [] 404 return unified 405 else: 406 self.recv_buffer = [unified[bufsize:]] 407 return unified[:bufsize] 408 409 410class continuous_frame(object): 411 412 def __init__(self, fire_cont_frame, skip_utf8_validation): 413 self.fire_cont_frame = fire_cont_frame 414 self.skip_utf8_validation = skip_utf8_validation 415 self.cont_data = None 416 self.recving_frames = None 417 418 def validate(self, frame): 419 if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT: 420 raise WebSocketProtocolException("Illegal frame") 421 if self.recving_frames and \ 422 frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): 423 raise WebSocketProtocolException("Illegal frame") 424 425 def add(self, frame): 426 if self.cont_data: 427 self.cont_data[1] += frame.data 428 else: 429 if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY): 430 self.recving_frames = frame.opcode 431 self.cont_data = [frame.opcode, frame.data] 432 433 if frame.fin: 434 self.recving_frames = None 435 436 def is_fire(self, frame): 437 return frame.fin or self.fire_cont_frame 438 439 def extract(self, frame): 440 data = self.cont_data 441 self.cont_data = None 442 frame.data = data[1] 443 if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data): 444 raise WebSocketPayloadException( 445 "cannot decode: " + repr(frame.data)) 446 447 return [data[0], frame] 448