1# Copyright 2011, Google Inc.
2# All rights reserved.
3#
4# Redistribution and use in source and binary forms, with or without
5# modification, are permitted provided that the following conditions are
6# met:
7#
8#     * Redistributions of source code must retain the above copyright
9# notice, this list of conditions and the following disclaimer.
10#     * Redistributions in binary form must reproduce the above
11# copyright notice, this list of conditions and the following disclaimer
12# in the documentation and/or other materials provided with the
13# distribution.
14#     * Neither the name of Google Inc. nor the names of its
15# contributors may be used to endorse or promote products derived from
16# this software without specific prior written permission.
17#
18# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29"""WebSocket utilities."""
30
31from __future__ import absolute_import
32import array
33import errno
34import logging
35import os
36import re
37import six
38from six.moves import map
39from six.moves import range
40import socket
41import struct
42import zlib
43
44try:
45    from mod_pywebsocket import fast_masking
46except ImportError:
47    pass
48
49
def prepend_message_to_exception(message, exc):
    """Rewrite exc.args in place so the exception text starts with message.

    The new args tuple holds a single string: message followed by the
    exception's current str() form. Nothing is returned.
    """
    current_text = str(exc)
    exc.args = (message + current_text,)
54
55
def __translate_interp(interp, cygwin_path):
    """Translate interp program path for Win32 python to run cygwin program
    (e.g. perl).

    Paths containing spaces are not supported. This is normally acceptable
    because #!-scripts come from Unix, where interpreter paths rarely contain
    spaces. For Win32 python, cygwin_path is a directory of cygwin binaries.

    Args:
      interp: interp command line
      cygwin_path: directory name of cygwin binary, or None
    Returns:
      translated interp command line. It is returned unchanged when
      cygwin_path is empty or interp has no '/'-separated program path.
    """
    if not cygwin_path:
        return interp
    m = re.match('^[^ ]*/([^ ]+)( .*)?', interp)
    if m:
        cmd = os.path.join(cygwin_path, m.group(1))
        # The argument group is optional: a #! line without arguments
        # (e.g. "/usr/bin/perl") leaves group(2) as None, which must not be
        # concatenated to a string.
        return cmd + (m.group(2) or '')
    return interp
75
76
def get_script_interp(script_path, cygwin_path=None):
    r"""Get #!-interpreter command line from the script.

    It also fixes command path.  When Cygwin Python is used, e.g. in WebKit,
    it could run "/usr/bin/perl -wT hello.pl".
    When Win32 Python is used, e.g. in Chromium, it couldn't.  So, fix
    "/usr/bin/perl" to "<cygwin_path>\perl.exe".

    Args:
      script_path: pathname of the script
      cygwin_path: directory name of cygwin binary, or None
    Returns:
      #!-interpreter command line, or None if it is not #!-script.
    """
    # Only the first line is needed. The context manager closes the file
    # even if reading it raises.
    with open(script_path) as fp:
        line = fp.readline()
    m = re.match('^#!(.*)', line)
    if m:
        return __translate_interp(m.group(1), cygwin_path)
    return None
98
99
def wrap_popen3_for_win(cygwin_path):
    """Replace os.popen3 with a wrapper that can launch #!-scripts.

    The wrapper takes the first space-separated token of the command as the
    script path. If that file starts with a #! line, the interpreter command
    line (translated through cygwin_path when given) is prepended to the
    command before delegating to the original os.popen3. Note that os.popen3
    only exists on Python 2.

    Args:
      cygwin_path:  path for cygwin binary if command path is needed to be
                    translated.  None if no translation required.
    """
    original_popen3 = os.popen3

    def popen3_with_interp(cmd, mode='t', bufsize=-1):
        script = cmd.split(' ')[0]
        interp = get_script_interp(script, cygwin_path)
        full_cmd = interp + ' ' + cmd if interp else cmd
        return original_popen3(full_cmd, mode, bufsize)

    os.popen3 = popen3_with_interp
117
118
def hexify(s):
    """Return the octets of byte string s as space-separated lowercase hex."""
    octets = six.iterbytes(s)
    return ' '.join('%02x' % octet for octet in octets)
121
122
def get_class_logger(o):
    """Return the logger named '<module>.<ClassName>' for the class of o."""
    klass = o.__class__
    logger_name = '%s.%s' % (klass.__module__, klass.__name__)
    return logging.getLogger(logger_name)
127
128
def pack_byte(b):
    """Encode integer b (0-255) as a one-octet byte string.

    Values outside 0-255 make struct.pack raise struct.error.
    """
    packed = struct.pack('!B', b)
    return packed
132
133
class NoopMasker(object):
    """Masker that leaves data untouched.

    It exposes the same mask() method as RepeatedXorMasker, so it can be
    used wherever a masker is expected but no masking should happen.
    """

    def __init__(self):
        """No state to set up."""

    def mask(self, s):
        """Return s as-is (the very same object)."""
        return s
147
148
class RepeatedXorMasker(object):
    """XOR masker that cycles through a fixed masking key.

    mask() XORs the given byte string with the masking key, repeating the
    key as often as needed. The key position reached at the end of one call
    is remembered, so masking data piece by piece gives the same bytes as
    masking it in one call.
    """

    def __init__(self, masking_key):
        self._masking_key = masking_key
        # Position in _masking_key where the next mask() call starts.
        self._masking_key_index = 0

    def _mask_using_swig(self, s):
        """Mask s with the optional native fast_masking extension."""
        start = self._masking_key_index
        masked = fast_masking.mask(s, self._masking_key, start)
        self._masking_key_index = (start + len(s)) % len(self._masking_key)
        return masked

    def _mask_using_array(self, s):
        """Mask s in pure Python and return the result as bytes.

        Raises:
          Exception: if s is a text (unicode) string instead of bytes.
        """
        if isinstance(s, six.text_type):
            raise Exception(
                'Masking Operation should not process unicode strings')

        source = bytearray(s)

        # Copy the key and cursor into locals so the loop below does not pay
        # for attribute lookups on every byte.
        key = list(six.iterbytes(self._masking_key))
        key_len = len(key)
        pos = self._masking_key_index

        masked = bytearray(len(source))
        for offset, octet in enumerate(source):
            masked[offset] = octet ^ key[pos]
            pos = (pos + 1) % key_len

        self._masking_key_index = pos
        return bytes(masked)

    # The implementation is chosen once, at class creation: the native one
    # if the optional fast_masking import at module load succeeded,
    # otherwise the pure-Python one.
    mask = (_mask_using_swig if 'fast_masking' in globals()
            else _mask_using_array)
195
196
# By making the wbits option negative, we suppress zlib's CMF/FLG (2 octets)
# and ADLER32 (4 octets) fields so that the zlib module can be used as a plain
# deflate library. DICTID won't be added as long as we don't set a dictionary.
# With wbits of zlib.MAX_WBITS, an LZ77 window of 32K is used for both
# compression and decompression. For decompression, a 32K window is enough to
# cover any window size the sender used. For compression, a 32K window means
# receivers must also use 32K.
#
# Compression level is Z_DEFAULT_COMPRESSION. The decoder does not need to
# know the level.
#
# See zconf.h, deflate.c, inflate.c of the zlib library, and zlibmodule.c of
# Python. See also RFC1950 (ZLIB 3.3).
209
210
class _Deflater(object):
    """Raw-DEFLATE compressor (no zlib header or trailer) with debug logging.

    Every compress* method logs its input and output at DEBUG level.
    """

    def __init__(self, window_bits):
        self._logger = get_class_logger(self)

        # The WebSocket spec allows window bits as small as 8, but zlib does
        # not accept window_bits = 8, so 9 is the smallest value used here.
        #
        # Because of a zlib deflate quirk, back-references do not span the
        # entire 1 << window_bits range but only (1 << window_bits) - 262.
        # With window_bits raised to 9, back-references therefore stay within
        # 250 bytes and remain decodable with window_bits = 8 and its
        # 256-byte window.
        #
        # A similar discussion can be found in https://crbug.com/691074
        effective_bits = max(window_bits, 9)

        # Negative wbits selects a raw deflate stream (see the module-level
        # comment above this class).
        self._compress = zlib.compressobj(
            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -effective_bits)

    def _process(self, data, flush_mode=None):
        """Compress data, flush with flush_mode if given, log, and return."""
        output = self._compress.compress(data)
        if flush_mode is not None:
            output += self._compress.flush(flush_mode)
        self._logger.debug('Compress input %r', data)
        self._logger.debug('Compress result %r', output)
        return output

    def compress(self, bytes):
        """Compress without flushing; zlib may keep some output buffered."""
        return self._process(bytes)

    def compress_and_flush(self, bytes):
        """Compress, then Z_SYNC_FLUSH so all input so far is decodable."""
        return self._process(bytes, zlib.Z_SYNC_FLUSH)

    def compress_and_finish(self, bytes):
        """Compress, then Z_FINISH to terminate the deflate stream."""
        return self._process(bytes, zlib.Z_FINISH)
250
251
class _Inflater(object):
    """Raw-DEFLATE decompressor that is fed incrementally via append().

    The buffered input may contain several deflate streams back to back. A
    fresh zlib decompressor is started whenever one stream ends.
    """

    def __init__(self, window_bits):
        self._logger = get_class_logger(self)
        self._window_bits = window_bits

        # Compressed bytes that have not yet been consumed by zlib.
        self._unconsumed = b''

        self.reset()

    def decompress(self, size):
        """Return up to size decompressed bytes (size == -1: no limit).

        Returns b'' when nothing can be produced from the buffered input.

        Raises:
          Exception: if size is neither -1 nor positive.
        """
        if not (size == -1 or size > 0):
            raise Exception('size must be -1 or positive')

        output = b''

        while True:
            # zlib treats max_length == 0 as "no limit", which is exactly
            # what size == -1 needs (max() clamps the negative value to 0).
            limit = max(0, size - len(output))
            output += self._decompress.decompress(self._unconsumed, limit)
            self._unconsumed = self._decompress.unconsumed_tail

            trailing = self._decompress.unused_data
            if not trailing:
                # No new stream follows. Either all requested bytes are in
                # output (the rest of the input waits in unconsumed_tail), or
                # the input is used up. There is no need to loop again.
                break

            # A block with BFINAL = 1 ended the current stream, and the bytes
            # after it (unused_data) belong to a new stream. A zlib
            # Decompress object cannot continue past the end of its stream,
            # so start a new one. unconsumed_tail can be dropped here because
            # unused_data is non-empty.
            self._unconsumed = trailing
            self.reset()

            if size >= 0 and len(output) == size:
                # The requested amount has been produced.
                break
            # Otherwise keep decompressing, so every byte already available
            # is returned before the caller blocks on reading more input.

        if output:
            self._logger.debug('Decompressed %r', output)
        return output

    def append(self, data):
        """Queue more compressed bytes for later decompress() calls."""
        self._logger.debug('Appended %r', data)
        self._unconsumed += data

    def reset(self):
        """Start a new raw zlib decompressor; buffered input is kept."""
        self._logger.debug('Reset')
        self._decompress = zlib.decompressobj(-self._window_bits)
306
307
308# Compresses/decompresses given octets using the method introduced in RFC1979.
309
310
class _RFC1979Deflater(object):
    """Compressor that applies DEFLATE to byte sequences and flushes them
    following RFC 1979 section 2.1.

    The underlying _Deflater is created lazily on the first filter() call.
    It is discarded after a final (bfinal) chunk, and also after each
    message end when no_context_takeover is set, so the next call starts a
    fresh stream.
    """

    def __init__(self, window_bits, no_context_takeover):
        self._deflater = None
        # None means "use the maximum window size".
        self._window_bits = (zlib.MAX_WBITS if window_bits is None
                             else window_bits)
        self._no_context_takeover = no_context_takeover

    def filter(self, bytes, end=True, bfinal=False):
        """Compress one chunk of data and return the compressed octets.

        Args:
          bytes: data to compress.
          end: when true (and bfinal is false), the chunk ends a message.
               The trailing 4 octets produced by Z_SYNC_FLUSH are then
               stripped.
          bfinal: when true, finish the deflate stream, append a padding
                  octet, and drop the compressor.
        """
        if self._deflater is None:
            self._deflater = _Deflater(self._window_bits)

        if bfinal:
            finished = self._deflater.compress_and_finish(bytes)
            # Append a padding block with BFINAL = 0 and BTYPE = 0.
            padded = finished + pack_byte(0)
            self._deflater = None
            return padded

        flushed = self._deflater.compress_and_flush(bytes)
        if end:
            # Z_SYNC_FLUSH ends with a non-compressed block whose LEN and
            # NLEN fields take the last 4 octets; drop them.
            flushed = flushed[:-4]
            if self._no_context_takeover:
                self._deflater = None

        return flushed
343
344
class _RFC1979Inflater(object):
    """A decompressor class a la RFC1979.

    A decompressor class for byte sequence compressed and flushed following
    the algorithm described in the RFC1979 section 2.1, i.e. Z_SYNC_FLUSH
    output whose trailing LEN/NLEN octets (00 00 ff ff) were stripped.
    """
    def __init__(self, window_bits=zlib.MAX_WBITS):
        # Accept None the same way _RFC1979Deflater does. Passing None
        # through would make _Inflater.reset() raise TypeError when it
        # negates the value.
        if window_bits is None:
            window_bits = zlib.MAX_WBITS
        self._inflater = _Inflater(window_bits)

    def filter(self, bytes):
        """Decompress one RFC 1979-flushed chunk and return all output."""
        # Restore stripped LEN and NLEN field of a non-compressed block added
        # for Z_SYNC_FLUSH.
        self._inflater.append(bytes + b'\x00\x00\xff\xff')
        return self._inflater.decompress(-1)
359
360
class DeflateSocket(object):
    """Socket wrapper that transparently compresses data passed to send and
    sendall and decompresses data returned by recv.

    Both directions use raw DEFLATE with a zlib.MAX_WBITS window. Only recv,
    send and sendall are provided.
    """

    # Size of the buffer passed to recv to receive compressed data.
    _RECV_SIZE = 4096

    def __init__(self, socket):
        self._socket = socket
        self._logger = get_class_logger(self)
        self._deflater = _Deflater(zlib.MAX_WBITS)
        self._inflater = _Inflater(zlib.MAX_WBITS)

    def recv(self, size):
        """Return up to size decompressed bytes as soon as any are available.

        Compressed data is read from the wrapped socket only when nothing can
        be decompressed from what is already buffered. Returns b'' when that
        read returns no data.

        Raises:
          Exception: if size is not positive.
        """
        # TODO(tyoshino): Allow call with size=0. It should block until any
        # decompressed data is available.
        if size <= 0:
            raise Exception('Non-positive size passed')

        while True:
            chunk = self._inflater.decompress(size)
            if chunk:
                return chunk

            compressed = self._socket.recv(DeflateSocket._RECV_SIZE)
            if not compressed:
                return b''
            self._inflater.append(compressed)

    def sendall(self, bytes):
        """Compress and send all of bytes; returns None."""
        self.send(bytes)

    def send(self, bytes):
        """Compress with a sync flush, send everything via the wrapped
        socket's sendall, and return len(bytes) (the uncompressed length).
        """
        payload = self._deflater.compress_and_flush(bytes)
        self._socket.sendall(payload)
        return len(bytes)
403
404
405# vi:sts=4 sw=4 et
406