1# Copyright 2011, Google Inc. 2# All rights reserved. 3# 4# Redistribution and use in source and binary forms, with or without 5# modification, are permitted provided that the following conditions are 6# met: 7# 8# * Redistributions of source code must retain the above copyright 9# notice, this list of conditions and the following disclaimer. 10# * Redistributions in binary form must reproduce the above 11# copyright notice, this list of conditions and the following disclaimer 12# in the documentation and/or other materials provided with the 13# distribution. 14# * Neither the name of Google Inc. nor the names of its 15# contributors may be used to endorse or promote products derived from 16# this software without specific prior written permission. 17# 18# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29"""WebSocket utilities.""" 30 31from __future__ import absolute_import 32import array 33import errno 34import logging 35import os 36import re 37import six 38from six.moves import map 39from six.moves import range 40import socket 41import struct 42import zlib 43 44try: 45 from mod_pywebsocket import fast_masking 46except ImportError: 47 pass 48 49 50def prepend_message_to_exception(message, exc): 51 """Prepend message to the exception.""" 52 exc.args = (message + str(exc), ) 53 return 54 55 56def __translate_interp(interp, cygwin_path): 57 """Translate interp program path for Win32 python to run cygwin program 58 (e.g. perl). Note that it doesn't support path that contains space, 59 which is typically true for Unix, where #!-script is written. 60 For Win32 python, cygwin_path is a directory of cygwin binaries. 61 62 Args: 63 interp: interp command line 64 cygwin_path: directory name of cygwin binary, or None 65 Returns: 66 translated interp command line. 67 """ 68 if not cygwin_path: 69 return interp 70 m = re.match('^[^ ]*/([^ ]+)( .*)?', interp) 71 if m: 72 cmd = os.path.join(cygwin_path, m.group(1)) 73 return cmd + m.group(2) 74 return interp 75 76 77def get_script_interp(script_path, cygwin_path=None): 78 r"""Get #!-interpreter command line from the script. 79 80 It also fixes command path. When Cygwin Python is used, e.g. in WebKit, 81 it could run "/usr/bin/perl -wT hello.pl". 82 When Win32 Python is used, e.g. in Chromium, it couldn't. So, fix 83 "/usr/bin/perl" to "<cygwin_path>\perl.exe". 84 85 Args: 86 script_path: pathname of the script 87 cygwin_path: directory name of cygwin binary, or None 88 Returns: 89 #!-interpreter command line, or None if it is not #!-script. 90 """ 91 fp = open(script_path) 92 line = fp.readline() 93 fp.close() 94 m = re.match('^#!(.*)', line) 95 if m: 96 return __translate_interp(m.group(1), cygwin_path) 97 return None 98 99 100def wrap_popen3_for_win(cygwin_path): 101 """Wrap popen3 to support #!-script on Windows. 102 103 Args: 104 cygwin_path: path for cygwin binary if command path is needed to be 105 translated. None if no translation required. 106 """ 107 __orig_popen3 = os.popen3 108 109 def __wrap_popen3(cmd, mode='t', bufsize=-1): 110 cmdline = cmd.split(' ') 111 interp = get_script_interp(cmdline[0], cygwin_path) 112 if interp: 113 cmd = interp + ' ' + cmd 114 return __orig_popen3(cmd, mode, bufsize) 115 116 os.popen3 = __wrap_popen3 117 118 119def hexify(s): 120 return ' '.join(['%02x' % x for x in six.iterbytes(s)]) 121 122 123def get_class_logger(o): 124 """Return the logging class information.""" 125 return logging.getLogger('%s.%s' % 126 (o.__class__.__module__, o.__class__.__name__)) 127 128 129def pack_byte(b): 130 """Pack an integer to network-ordered byte""" 131 return struct.pack('!B', b) 132 133 134class NoopMasker(object): 135 """A NoOp masking object. 136 137 This has the same interface as RepeatedXorMasker but just returns 138 the string passed in without making any change. 139 """ 140 def __init__(self): 141 """NoOp.""" 142 pass 143 144 def mask(self, s): 145 """NoOp.""" 146 return s 147 148 149class RepeatedXorMasker(object): 150 """A masking object that applies XOR on the string. 151 152 Applies XOR on the byte string given to mask method with the masking bytes 153 given to the constructor repeatedly. This object remembers the position 154 in the masking bytes the last mask method call ended and resumes from 155 that point on the next mask method call. 156 """ 157 def __init__(self, masking_key): 158 self._masking_key = masking_key 159 self._masking_key_index = 0 160 161 def _mask_using_swig(self, s): 162 """Perform the mask via SWIG.""" 163 masked_data = fast_masking.mask(s, self._masking_key, 164 self._masking_key_index) 165 self._masking_key_index = ((self._masking_key_index + len(s)) % 166 len(self._masking_key)) 167 return masked_data 168 169 def _mask_using_array(self, s): 170 """Perform the mask via python.""" 171 if isinstance(s, six.text_type): 172 raise Exception( 173 'Masking Operation should not process unicode strings') 174 175 result = bytearray(s) 176 177 # Use temporary local variables to eliminate the cost to access 178 # attributes 179 masking_key = [c for c in six.iterbytes(self._masking_key)] 180 masking_key_size = len(masking_key) 181 masking_key_index = self._masking_key_index 182 183 for i in range(len(result)): 184 result[i] ^= masking_key[masking_key_index] 185 masking_key_index = (masking_key_index + 1) % masking_key_size 186 187 self._masking_key_index = masking_key_index 188 189 return bytes(result) 190 191 if 'fast_masking' in globals(): 192 mask = _mask_using_swig 193 else: 194 mask = _mask_using_array 195 196 197# By making wbits option negative, we can suppress CMF/FLG (2 octet) and 198# ADLER32 (4 octet) fields of zlib so that we can use zlib module just as 199# deflate library. DICTID won't be added as far as we don't set dictionary. 200# LZ77 window of 32K will be used for both compression and decompression. 201# For decompression, we can just use 32K to cover any windows size. For 202# compression, we use 32K so receivers must use 32K. 203# 204# Compression level is Z_DEFAULT_COMPRESSION. We don't have to match level 205# to decode. 206# 207# See zconf.h, deflate.cc, inflate.cc of zlib library, and zlibmodule.c of 208# Python. See also RFC1950 (ZLIB 3.3). 209 210 211class _Deflater(object): 212 def __init__(self, window_bits): 213 self._logger = get_class_logger(self) 214 215 # Using the smallest window bits of 9 for generating input frames. 216 # On WebSocket spec, the smallest window bit is 8. However, zlib does 217 # not accept window_bit = 8. 218 # 219 # Because of a zlib deflate quirk, back-references will not use the 220 # entire range of 1 << window_bits, but will instead use a restricted 221 # range of (1 << window_bits) - 262. With an increased window_bits = 9, 222 # back-references will be within a range of 250. These can still be 223 # decompressed with window_bits = 8 and the 256-byte window used there. 224 # 225 # Similar disscussions can be found in https://crbug.com/691074 226 window_bits = max(window_bits, 9) 227 228 self._compress = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, 229 zlib.DEFLATED, -window_bits) 230 231 def compress(self, bytes): 232 compressed_bytes = self._compress.compress(bytes) 233 self._logger.debug('Compress input %r', bytes) 234 self._logger.debug('Compress result %r', compressed_bytes) 235 return compressed_bytes 236 237 def compress_and_flush(self, bytes): 238 compressed_bytes = self._compress.compress(bytes) 239 compressed_bytes += self._compress.flush(zlib.Z_SYNC_FLUSH) 240 self._logger.debug('Compress input %r', bytes) 241 self._logger.debug('Compress result %r', compressed_bytes) 242 return compressed_bytes 243 244 def compress_and_finish(self, bytes): 245 compressed_bytes = self._compress.compress(bytes) 246 compressed_bytes += self._compress.flush(zlib.Z_FINISH) 247 self._logger.debug('Compress input %r', bytes) 248 self._logger.debug('Compress result %r', compressed_bytes) 249 return compressed_bytes 250 251 252class _Inflater(object): 253 def __init__(self, window_bits): 254 self._logger = get_class_logger(self) 255 self._window_bits = window_bits 256 257 self._unconsumed = b'' 258 259 self.reset() 260 261 def decompress(self, size): 262 if not (size == -1 or size > 0): 263 raise Exception('size must be -1 or positive') 264 265 data = b'' 266 267 while True: 268 data += self._decompress.decompress(self._unconsumed, 269 max(0, size - len(data))) 270 self._unconsumed = self._decompress.unconsumed_tail 271 if self._decompress.unused_data: 272 # Encountered a last block (i.e. a block with BFINAL = 1) and 273 # found a new stream (unused_data). We cannot use the same 274 # zlib.Decompress object for the new stream. Create a new 275 # Decompress object to decompress the new one. 276 # 277 # It's fine to ignore unconsumed_tail if unused_data is not 278 # empty. 279 self._unconsumed = self._decompress.unused_data 280 self.reset() 281 if size >= 0 and len(data) == size: 282 # data is filled. Don't call decompress again. 283 break 284 else: 285 # Re-invoke Decompress.decompress to try to decompress all 286 # available bytes before invoking read which blocks until 287 # any new byte is available. 288 continue 289 else: 290 # Here, since unused_data is empty, even if unconsumed_tail is 291 # not empty, bytes of requested length are already in data. We 292 # don't have to "continue" here. 293 break 294 295 if data: 296 self._logger.debug('Decompressed %r', data) 297 return data 298 299 def append(self, data): 300 self._logger.debug('Appended %r', data) 301 self._unconsumed += data 302 303 def reset(self): 304 self._logger.debug('Reset') 305 self._decompress = zlib.decompressobj(-self._window_bits) 306 307 308# Compresses/decompresses given octets using the method introduced in RFC1979. 309 310 311class _RFC1979Deflater(object): 312 """A compressor class that applies DEFLATE to given byte sequence and 313 flushes using the algorithm described in the RFC1979 section 2.1. 314 """ 315 def __init__(self, window_bits, no_context_takeover): 316 self._deflater = None 317 if window_bits is None: 318 window_bits = zlib.MAX_WBITS 319 self._window_bits = window_bits 320 self._no_context_takeover = no_context_takeover 321 322 def filter(self, bytes, end=True, bfinal=False): 323 if self._deflater is None: 324 self._deflater = _Deflater(self._window_bits) 325 326 if bfinal: 327 result = self._deflater.compress_and_finish(bytes) 328 # Add a padding block with BFINAL = 0 and BTYPE = 0. 329 result = result + pack_byte(0) 330 self._deflater = None 331 return result 332 333 result = self._deflater.compress_and_flush(bytes) 334 if end: 335 # Strip last 4 octets which is LEN and NLEN field of a 336 # non-compressed block added for Z_SYNC_FLUSH. 337 result = result[:-4] 338 339 if self._no_context_takeover and end: 340 self._deflater = None 341 342 return result 343 344 345class _RFC1979Inflater(object): 346 """A decompressor class a la RFC1979. 347 348 A decompressor class for byte sequence compressed and flushed following 349 the algorithm described in the RFC1979 section 2.1. 350 """ 351 def __init__(self, window_bits=zlib.MAX_WBITS): 352 self._inflater = _Inflater(window_bits) 353 354 def filter(self, bytes): 355 # Restore stripped LEN and NLEN field of a non-compressed block added 356 # for Z_SYNC_FLUSH. 357 self._inflater.append(bytes + b'\x00\x00\xff\xff') 358 return self._inflater.decompress(-1) 359 360 361class DeflateSocket(object): 362 """A wrapper class for socket object to intercept send and recv to perform 363 deflate compression and decompression transparently. 364 """ 365 366 # Size of the buffer passed to recv to receive compressed data. 367 _RECV_SIZE = 4096 368 369 def __init__(self, socket): 370 self._socket = socket 371 372 self._logger = get_class_logger(self) 373 374 self._deflater = _Deflater(zlib.MAX_WBITS) 375 self._inflater = _Inflater(zlib.MAX_WBITS) 376 377 def recv(self, size): 378 """Receives data from the socket specified on the construction up 379 to the specified size. Once any data is available, returns it even 380 if it's smaller than the specified size. 381 """ 382 383 # TODO(tyoshino): Allow call with size=0. It should block until any 384 # decompressed data is available. 385 if size <= 0: 386 raise Exception('Non-positive size passed') 387 while True: 388 data = self._inflater.decompress(size) 389 if len(data) != 0: 390 return data 391 392 read_data = self._socket.recv(DeflateSocket._RECV_SIZE) 393 if not read_data: 394 return b'' 395 self._inflater.append(read_data) 396 397 def sendall(self, bytes): 398 self.send(bytes) 399 400 def send(self, bytes): 401 self._socket.sendall(self._deflater.compress_and_flush(bytes)) 402 return len(bytes) 403 404 405# vi:sts=4 sw=4 et 406