#!/usr/bin/env python
#
# Copyright 2011, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests for util module."""

from __future__ import absolute_import
from __future__ import print_function

import os
import random
import struct
import sys
import unittest

import set_sys_path  # Update sys.path to locate mod_pywebsocket module.

from mod_pywebsocket import util
from six.moves import range
from six import PY3
from six import int2byte

# Directory holding the fixture files (README, hello.pl) used below.
_TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), 'testdata')


class UtilTest(unittest.TestCase):
    """A unittest for util module."""
    def test_prepend_message_to_exception(self):
        """prepend_message_to_exception prefixes the exception's message."""
        exc = Exception('World')
        self.assertEqual('World', str(exc))
        util.prepend_message_to_exception('Hello ', exc)
        self.assertEqual('Hello World', str(exc))

    def test_get_script_interp(self):
        """Check get_script_interp on the testdata fixtures.

        Each fixture is checked both without and with a cygwin path; with
        the cygwin path the expected interpreter is rooted there.
        """
        cygwin_path = 'c:\\cygwin\\bin'
        cygwin_perl = os.path.join(cygwin_path, 'perl')
        readme = os.path.join(_TEST_DATA_DIR, 'README')
        hello_pl = os.path.join(_TEST_DATA_DIR, 'hello.pl')

        self.assertEqual(None, util.get_script_interp(readme))
        self.assertEqual(None, util.get_script_interp(readme, cygwin_path))
        self.assertEqual('/usr/bin/perl -wT',
                         util.get_script_interp(hello_pl))
        self.assertEqual(cygwin_perl + ' -wT',
                         util.get_script_interp(hello_pl, cygwin_path))

    def test_hexify(self):
        """hexify renders bytes as space-separated two-digit lowercase hex."""
        self.assertEqual('61 7a 41 5a 30 39 20 09 0d 0a 00 ff',
                         util.hexify(b'azAZ09 \t\r\n\x00\xff'))


class RepeatedXorMaskerTest(unittest.TestCase):
    """A unittest for RepeatedXorMasker class."""
    def test_mask(self):
        """XOR with all-ones, all-zeros and self-matching masks."""
        # Sample input e6,97,a5 is U+65e5 in UTF-8
        masker = util.RepeatedXorMasker(b'\xff\xff\xff\xff')
        result = masker.mask(b'\xe6\x97\xa5')
        self.assertEqual(b'\x19\x68\x5a', result)

        masker = util.RepeatedXorMasker(b'\x00\x00\x00\x00')
        result = masker.mask(b'\xe6\x97\xa5')
        self.assertEqual(b'\xe6\x97\xa5', result)

        masker = util.RepeatedXorMasker(b'\xe6\x97\xa5\x20')
        result = masker.mask(b'\xe6\x97\xa5')
        self.assertEqual(b'\x00\x00\x00', result)

    def test_mask_twice(self):
        """The mask offset carries over between successive mask() calls."""
        masker = util.RepeatedXorMasker(b'\x00\x7f\xff\x20')
        # mask[0], mask[1], ... will be used.
        result = masker.mask(b'\x00\x00\x00\x00\x00')
        self.assertEqual(b'\x00\x7f\xff\x20\x00', result)
        # 5 bytes were consumed (5 % 4 == 1), so mask[1], mask[2], ... will
        # be used for the next call.
        result = masker.mask(b'\x00\x00\x00\x00\x00')
        self.assertEqual(b'\x7f\xff\x20\x00\x7f', result)

    def test_mask_large_data(self):
        """Mask 1000 bytes, then two consecutive text payloads."""
        masker = util.RepeatedXorMasker(b'mASk')
        original = b''.join([util.pack_byte(i % 256) for i in range(1000)])
        result = masker.mask(original)
        expected = b''.join([
            util.pack_byte((i % 256) ^ ord('mASk'[i % 4])) for i in range(1000)
        ])
        self.assertEqual(expected, result)

        masker = util.RepeatedXorMasker(b'MaSk')
        first_part = b'The WebSocket Protocol enables two-way communication.'
        result = masker.mask(first_part)
        self.assertEqual(
            b'\x19\t6K\x1a\x0418"\x028\x0e9A\x03\x19"\x15<\x08"\rs\x0e#'
            b'\x001\x07(\x12s\x1f:\x0e~\x1c,\x18s\x08"\x0c>\x1e#\x080\n9'
            b'\x08<\x05c', result)
        # first_part is 53 bytes long, so masking resumes at mask[1].
        second_part = b'It has two parts: a handshake and the data transfer.'
        result = masker.mask(second_part)
        self.assertEqual(
            b"('K%\x00 K9\x16<K=\x00!\x1f>[s\nm\t2\x05)\x12;\n&\x04s\n#"
            b"\x05s\x1f%\x04s\x0f,\x152K9\x132\x05>\x076\x19c", result)


def get_random_section(source, min_num_chunks, rng=None):
    """Split source into consecutive chunks of random size.

    Every chunk is between 1 and max(1, len(source) // min_num_chunks)
    items long, so at least min_num_chunks chunks are produced whenever
    len(source) >= min_num_chunks. Joining the chunks reproduces source.

    Args:
        source: sliceable sequence (bytes in this file) to split.
        min_num_chunks: lower bound on the number of chunks (see above).
        rng: object with a randint(a, b) method, e.g. random.Random.
            Defaults to the global random module.

    Returns:
        A list of consecutive slices of source.

    Raises:
        ZeroDivisionError: if min_num_chunks is 0.
    """
    if rng is None:
        rng = random

    # Floor division: on Python 3 '/' yields a float, which randint()
    # rejects (TypeError on 3.12+, ValueError for non-integral values on
    # older versions). max(1, ...) keeps the range non-empty when source
    # is shorter than min_num_chunks.
    max_chunk_size = max(1, len(source) // min_num_chunks)

    chunks = []
    bytes_chunked = 0
    while bytes_chunked < len(source):
        remaining = len(source) - bytes_chunked
        chunk_size = rng.randint(1, min(max_chunk_size, remaining))
        chunks.append(source[bytes_chunked:bytes_chunked + chunk_size])
        bytes_chunked += chunk_size

    return chunks


class InflaterDeflaterTest(unittest.TestCase):
    """A unittest for _Inflater and _Deflater class."""
    def test_inflate_deflate_default(self):
        """Round-trip one payload with constructor arguments 15 and 8.

        The argument is presumably the zlib window-bits value (TODO
        confirm against util); the two settings must produce different
        compressed bytes but both must decompress to the input.
        """
        data = b'hello' + b'-' * 30000 + b'hello'
        inflater15 = util._Inflater(15)
        deflater15 = util._Deflater(15)
        inflater8 = util._Inflater(8)
        deflater8 = util._Deflater(8)

        compressed15 = deflater15.compress_and_finish(data)
        compressed8 = deflater8.compress_and_finish(data)

        inflater15.append(compressed15)
        inflater8.append(compressed8)

        self.assertNotEqual(compressed15, compressed8)
        self.assertEqual(data, inflater15.decompress(-1))
        self.assertEqual(data, inflater8.decompress(-1))

    def test_random_section(self):
        """Round-trip random data using unrelated chunk boundaries.

        The data is compressed in one random chunking and decompressed in
        another, independently chosen chunking.
        """
        # A private, seeded generator keeps the test deterministic without
        # reseeding the global random module shared with other tests.
        rng = random.Random(0)
        source = b''.join(
            [int2byte(rng.randint(0, 255)) for _ in range(100 * 1024)])

        chunked_input = get_random_section(source, 10, rng)

        deflater = util._Deflater(15)
        compressed = [deflater.compress(chunk) for chunk in chunked_input]
        compressed.append(deflater.compress_and_finish(b''))

        chunked_expectation = get_random_section(source, 10, rng)

        inflater = util._Inflater(15)
        inflater.append(b''.join(compressed))
        for chunk in chunked_expectation:
            # decompress(n) is expected to yield exactly the next n bytes.
            decompressed = inflater.decompress(len(chunk))
            self.assertEqual(chunk, decompressed)

        # Everything has been consumed; nothing may remain.
        self.assertEqual(b'', inflater.decompress(-1))


if __name__ == '__main__':
    unittest.main()

# vi:sts=4 sw=4 et