1#!/usr/bin/env python 2# 3# Copyright 2011, Google Inc. 4# All rights reserved. 5# 6# Redistribution and use in source and binary forms, with or without 7# modification, are permitted provided that the following conditions are 8# met: 9# 10# * Redistributions of source code must retain the above copyright 11# notice, this list of conditions and the following disclaimer. 12# * Redistributions in binary form must reproduce the above 13# copyright notice, this list of conditions and the following disclaimer 14# in the documentation and/or other materials provided with the 15# distribution. 16# * Neither the name of Google Inc. nor the names of its 17# contributors may be used to endorse or promote products derived from 18# this software without specific prior written permission. 19# 20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31"""Tests for handshake module.""" 32 33from __future__ import absolute_import 34import unittest 35 36import set_sys_path # Update sys.path to locate mod_pywebsocket module. 
from mod_pywebsocket import common
from mod_pywebsocket.handshake._base import AbortedByUserException
from mod_pywebsocket.handshake._base import HandshakeException
from mod_pywebsocket.handshake._base import VersionException
from mod_pywebsocket.handshake.hybi import Handshaker

from test import mock


class RequestDefinition(object):
    """A class for holding data for constructing opening handshake strings for
    testing the opening handshake processor.

    The attributes are plain and mutable so that a test can derive a bad
    request from a good one (e.g. by reassigning ``method`` or editing
    ``headers``).
    """
    def __init__(self, method, uri, headers):
        self.method = method  # HTTP method, e.g. 'GET'.
        self.uri = uri  # Request target, e.g. '/demo'.
        self.headers = headers  # dict mapping header name to header value.


def _create_good_request_def():
    """Return a new RequestDefinition describing a valid opening handshake.

    The key/accept pair matches the sample handshake in RFC 6455 section 1.3.
    A fresh headers dict is built on every call, so callers may mutate it.
    """
    return RequestDefinition(
        'GET', '/demo', {
            'Host': 'server.example.com',
            'Upgrade': 'websocket',
            'Connection': 'Upgrade',
            'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
            'Sec-WebSocket-Version': '13',
            'Origin': 'http://example.com'
        })


def _create_request(request_def):
    """Build a mock.MockRequest for request_def over a MockConn created with
    empty input data.
    """
    conn = mock.MockConn(b'')
    return mock.MockRequest(method=request_def.method,
                            uri=request_def.uri,
                            headers_in=request_def.headers,
                            connection=conn)


def _create_handshaker(request):
    """Return a Handshaker for request using a default mock.MockDispatcher."""
    return Handshaker(request, mock.MockDispatcher())


class SubprotocolChoosingDispatcher(object):
    """A dispatcher for testing.

    In do_extra_handshake, this dispatcher sets ws_protocol to the index-th
    entry of ws_requested_protocols, where index is given on construction.
    If index is negative, ws_protocol is set to default_value instead.
    """
    def __init__(self, index, default_value=None):
        self.index = index
        self.default_value = default_value

    def do_extra_handshake(self, conn_context):
        if self.index >= 0:
            conn_context.ws_protocol = conn_context.ws_requested_protocols[
                self.index]
        else:
            conn_context.ws_protocol = self.default_value

    def transfer_data(self, conn_context):
        pass


class HandshakeAbortedException(Exception):
    """Raised by AbortingDispatcher to reject a handshake."""
    pass


class AbortingDispatcher(object):
    """A dispatcher for testing. This dispatcher raises an exception in
    do_extra_handshake to reject the request.
    """
    def do_extra_handshake(self, conn_context):
        raise HandshakeAbortedException('An exception to reject the request')

    def transfer_data(self, conn_context):
        pass


class AbortedByUserDispatcher(object):
    """A dispatcher for testing. This dispatcher raises an
    AbortedByUserException in do_extra_handshake to reject the request.
    """
    def do_extra_handshake(self, conn_context):
        raise AbortedByUserException('An AbortedByUserException to reject the '
                                     'request')

    def transfer_data(self, conn_context):
        pass


# Response expected for the request built by _create_good_request_def()
# when no subprotocol or extension is negotiated.
_EXPECTED_RESPONSE = (
    b'HTTP/1.1 101 Switching Protocols\r\n'
    b'Upgrade: websocket\r\n'
    b'Connection: Upgrade\r\n'
    b'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n')


class HandshakerTest(unittest.TestCase):
    """A unittest for draft-ietf-hybi-thewebsocketprotocol-06 and later
    handshake processor.
    """
    def test_do_handshake(self):
        request = _create_request(_create_good_request_def())
        dispatcher = mock.MockDispatcher()
        handshaker = Handshaker(request, dispatcher)
        handshaker.do_handshake()

        self.assertTrue(dispatcher.do_extra_handshake_called)

        self.assertEqual(_EXPECTED_RESPONSE, request.connection.written_data())
        self.assertEqual('/demo', request.ws_resource)
        self.assertEqual('http://example.com', request.ws_origin)
        self.assertIsNone(request.ws_protocol)
        self.assertIsNone(request.ws_extensions)
        self.assertEqual(common.VERSION_HYBI_LATEST, request.ws_version)

    def test_do_handshake_with_extra_headers(self):
        request_def = _create_good_request_def()
        # Add headers not related to WebSocket opening handshake.
        request_def.headers['FooKey'] = 'BarValue'
        request_def.headers['EmptyKey'] = ''

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(_EXPECTED_RESPONSE, request.connection.written_data())

    def test_do_handshake_with_capitalized_value(self):
        request_def = _create_good_request_def()
        # Overwrite the existing 'Upgrade' header rather than adding a second
        # key that differs only in case; otherwise which value the handshaker
        # sees would depend on dict ordering and the mock's key folding.
        request_def.headers['Upgrade'] = 'WEBSOCKET'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(_EXPECTED_RESPONSE, request.connection.written_data())

        request_def = _create_good_request_def()
        request_def.headers['Connection'] = 'UPGRADE'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(_EXPECTED_RESPONSE, request.connection.written_data())

    def test_do_handshake_with_multiple_connection_values(self):
        request_def = _create_good_request_def()
        request_def.headers['Connection'] = 'Upgrade, keep-alive, , '

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(_EXPECTED_RESPONSE, request.connection.written_data())

    def test_aborting_handshake(self):
        handshaker = Handshaker(_create_request(_create_good_request_def()),
                                AbortingDispatcher())
        # do_extra_handshake raises an exception. Check that it's not caught by
        # do_handshake.
        self.assertRaises(HandshakeAbortedException, handshaker.do_handshake)

    def test_do_handshake_with_protocol(self):
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Protocol'] = 'chat, superchat'

        request = _create_request(request_def)
        handshaker = Handshaker(request, SubprotocolChoosingDispatcher(0))
        handshaker.do_handshake()

        EXPECTED_RESPONSE = (
            b'HTTP/1.1 101 Switching Protocols\r\n'
            b'Upgrade: websocket\r\n'
            b'Connection: Upgrade\r\n'
            b'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n'
            b'Sec-WebSocket-Protocol: chat\r\n\r\n')

        self.assertEqual(EXPECTED_RESPONSE, request.connection.written_data())
        self.assertEqual('chat', request.ws_protocol)

    def test_do_handshake_protocol_not_in_request_but_in_response(self):
        request_def = _create_good_request_def()
        request = _create_request(request_def)
        handshaker = Handshaker(request,
                                SubprotocolChoosingDispatcher(-1, 'foobar'))
        # No request has been made but ws_protocol is set. HandshakeException
        # must be raised.
        self.assertRaises(HandshakeException, handshaker.do_handshake)

    def test_do_handshake_with_protocol_no_protocol_selection(self):
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Protocol'] = 'chat, superchat'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        # ws_protocol is not set. HandshakeException must be raised.
        self.assertRaises(HandshakeException, handshaker.do_handshake)

    def test_do_handshake_with_extensions(self):
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Extensions'] = (
            'permessage-deflate; server_no_context_takeover')

        EXPECTED_RESPONSE = (
            b'HTTP/1.1 101 Switching Protocols\r\n'
            b'Upgrade: websocket\r\n'
            b'Connection: Upgrade\r\n'
            b'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n'
            b'Sec-WebSocket-Extensions: '
            b'permessage-deflate; server_no_context_takeover\r\n'
            b'\r\n')

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(EXPECTED_RESPONSE, request.connection.written_data())
        self.assertEqual(1, len(request.ws_extensions))
        extension = request.ws_extensions[0]
        self.assertEqual(common.PERMESSAGE_DEFLATE_EXTENSION, extension.name())
        self.assertEqual(['server_no_context_takeover'],
                         extension.get_parameter_names())
        self.assertIsNone(
            extension.get_parameter_value('server_no_context_takeover'))
        self.assertEqual(1, len(request.ws_extension_processors))
        self.assertEqual('deflate', request.ws_extension_processors[0].name())

    def test_do_handshake_with_quoted_extensions(self):
        request_def = _create_good_request_def()
        # Includes an empty list element, whitespace around '=', and a
        # quoted-string parameter containing escapes and folded whitespace.
        request_def.headers['Sec-WebSocket-Extensions'] = (
            'permessage-deflate, , '
            'unknown; e = "mc^2"; ma="\r\n \\\rf "; pv=nrt')

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(2, len(request.ws_requested_extensions))
        first_extension = request.ws_requested_extensions[0]
        self.assertEqual('permessage-deflate', first_extension.name())
        second_extension = request.ws_requested_extensions[1]
        self.assertEqual('unknown', second_extension.name())
        self.assertEqual(['e', 'ma', 'pv'],
                         second_extension.get_parameter_names())
        self.assertEqual('mc^2', second_extension.get_parameter_value('e'))
        self.assertEqual(' \rf ', second_extension.get_parameter_value('ma'))
        self.assertEqual('nrt', second_extension.get_parameter_value('pv'))

    def test_do_handshake_with_optional_headers(self):
        request_def = _create_good_request_def()
        request_def.headers['EmptyValue'] = ''
        request_def.headers['AKey'] = 'AValue'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual('AValue', request.headers_in['AKey'])
        self.assertEqual('', request.headers_in['EmptyValue'])

    def test_abort_extra_handshake(self):
        handshaker = Handshaker(_create_request(_create_good_request_def()),
                                AbortedByUserDispatcher())
        # do_extra_handshake raises an AbortedByUserException. Check that it's
        # not caught by do_handshake.
        self.assertRaises(AbortedByUserException, handshaker.do_handshake)

    def test_bad_requests(self):
        # Each case is (case_name, request_def, expected_status,
        # expect_handshake_exception). When expect_handshake_exception is
        # False, a VersionException is expected instead and expected_status
        # is not checked.
        bad_cases = [
            ('HTTP request',
             RequestDefinition(
                 'GET', '/demo', {
                     'Host':
                     'www.google.com',
                     'User-Agent':
                     'Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10.5;'
                     ' en-US; rv:1.9.1.3) Gecko/20090824 Firefox/3.5.3'
                     ' GTB6 GTBA',
                     'Accept':
                     'text/html,application/xhtml+xml,application/xml;q=0.9,'
                     '*/*;q=0.8',
                     'Accept-Language':
                     'en-us,en;q=0.5',
                     'Accept-Encoding':
                     'gzip,deflate',
                     'Accept-Charset':
                     'ISO-8859-1,utf-8;q=0.7,*;q=0.7',
                     'Keep-Alive':
                     '300',
                     'Connection':
                     'keep-alive'
                 }), None, True)
        ]

        request_def = _create_good_request_def()
        request_def.method = 'POST'
        bad_cases.append(('Wrong method', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Host']
        bad_cases.append(('Missing Host', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Upgrade']
        bad_cases.append(('Missing Upgrade', request_def, None, True))

        request_def = _create_good_request_def()
        request_def.headers['Upgrade'] = 'nonwebsocket'
        bad_cases.append(('Wrong Upgrade', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Connection']
        bad_cases.append(('Missing Connection', request_def, None, True))

        request_def = _create_good_request_def()
        request_def.headers['Connection'] = 'Downgrade'
        bad_cases.append(('Wrong Connection', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Sec-WebSocket-Key']
        bad_cases.append(('Missing Sec-WebSocket-Key', request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Key'] = (
            'dGhlIHNhbXBsZSBub25jZQ==garbage')
        bad_cases.append(('Wrong Sec-WebSocket-Key (with garbage on the tail)',
                          request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Key'] = 'YQ=='  # BASE64 of 'a'
        bad_cases.append(
            ('Wrong Sec-WebSocket-Key (decoded value is not 16 octets long)',
             request_def, 400, True))

        request_def = _create_good_request_def()
        # The last character right before == must be any of A, Q, w and g.
        request_def.headers['Sec-WebSocket-Key'] = 'AQIDBAUGBwgJCgsMDQ4PEC=='
        bad_cases.append(
            ('Wrong Sec-WebSocket-Key (padding bits are not zero)',
             request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Key'] = (
            'dGhlIHNhbXBsZSBub25jZQ==,dGhlIHNhbXBsZSBub25jZQ==')
        bad_cases.append(('Wrong Sec-WebSocket-Key (multiple values)',
                          request_def, 400, True))

        request_def = _create_good_request_def()
        del request_def.headers['Sec-WebSocket-Version']
        bad_cases.append(
            ('Missing Sec-WebSocket-Version', request_def, None, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Version'] = '3'
        bad_cases.append(
            ('Wrong Sec-WebSocket-Version', request_def, None, False))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Version'] = '13, 13'
        bad_cases.append(('Wrong Sec-WebSocket-Version (multiple values)',
                          request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Protocol'] = 'illegal\x09protocol'
        bad_cases.append(
            ('Illegal Sec-WebSocket-Protocol', request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Protocol'] = ''
        bad_cases.append(
            ('Empty Sec-WebSocket-Protocol', request_def, 400, True))

        for (case_name, request_def, expected_status,
             expect_handshake_exception) in bad_cases:
            request = _create_request(request_def)
            handshaker = Handshaker(request, mock.MockDispatcher())
            # Keep only the call under test inside the try so that assertion
            # failures below are never confused with handshake errors.
            try:
                handshaker.do_handshake()
            except HandshakeException as e:
                self.assertTrue(
                    expect_handshake_exception,
                    'Unexpected HandshakeException for \'%s\' case' %
                    case_name)
                self.assertEqual(
                    expected_status, e.status,
                    'Wrong status for \'%s\' case' % case_name)
            except VersionException:
                self.assertFalse(
                    expect_handshake_exception,
                    'Unexpected VersionException for \'%s\' case' % case_name)
            else:
                self.fail('No exception thrown for \'%s\' case' % case_name)


if __name__ == '__main__':
    unittest.main()
# vi:sts=4 sw=4 et