1#!/usr/bin/env python
2#
3# Copyright 2011, Google Inc.
4# All rights reserved.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are
8# met:
9#
10#     * Redistributions of source code must retain the above copyright
11# notice, this list of conditions and the following disclaimer.
12#     * Redistributions in binary form must reproduce the above
13# copyright notice, this list of conditions and the following disclaimer
14# in the documentation and/or other materials provided with the
15# distribution.
16#     * Neither the name of Google Inc. nor the names of its
17# contributors may be used to endorse or promote products derived from
18# this software without specific prior written permission.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31"""Tests for handshake module."""
32
33from __future__ import absolute_import
34import unittest
35
36import set_sys_path  # Update sys.path to locate mod_pywebsocket module.
37from mod_pywebsocket import common
38from mod_pywebsocket.handshake._base import AbortedByUserException
39from mod_pywebsocket.handshake._base import HandshakeException
40from mod_pywebsocket.handshake._base import VersionException
41from mod_pywebsocket.handshake.hybi import Handshaker
42
43from test import mock
44
45
class RequestDefinition(object):
    """Plain record describing an opening handshake request for the tests.

    It holds the HTTP method, the request URI and a dict of request headers.
    Tests mutate these attributes freely before converting the definition
    into a mock request.
    """
    def __init__(self, method, uri, headers):
        # Store the three pieces verbatim; the headers dict is kept by
        # reference, not copied.
        self.method, self.uri, self.headers = method, uri, headers
54
55
def _create_good_request_def():
    """Build a fresh RequestDefinition for a valid version-13 handshake.

    Sec-WebSocket-Key is the base64 encoding of "the sample nonce", the
    sample key used in RFC 6455. A brand-new headers dict is created on
    every call, so callers can mutate the result without affecting other
    tests.
    """
    good_headers = {
        'Host': 'server.example.com',
        'Upgrade': 'websocket',
        'Connection': 'Upgrade',
        'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
        'Sec-WebSocket-Version': '13',
        'Origin': 'http://example.com',
    }
    return RequestDefinition('GET', '/demo', good_headers)
66
67
def _create_request(request_def):
    """Convert a RequestDefinition into a mock.MockRequest.

    The request's connection is a mock.MockConn seeded with empty data
    (b''). Tests read back what the handshake wrote through
    ``request.connection.written_data()``.
    """
    return mock.MockRequest(
        method=request_def.method,
        uri=request_def.uri,
        headers_in=request_def.headers,
        connection=mock.MockConn(b''),
    )
74
75
def _create_handshaker(request):
    """Return a hybi Handshaker for *request* using a fresh MockDispatcher."""
    return Handshaker(request, mock.MockDispatcher())
79
80
class SubprotocolChoosingDispatcher(object):
    """Test dispatcher that selects a subprotocol in the extra handshake.

    With a non-negative ``index`` it copies entry ``index`` of
    ``conn_context.ws_requested_protocols`` into ``conn_context.ws_protocol``.
    With a negative ``index`` it stores ``default_value`` there instead,
    whatever the client requested.
    """
    def __init__(self, index, default_value=None):
        self.index = index
        self.default_value = default_value

    def do_extra_handshake(self, conn_context):
        if self.index < 0:
            chosen = self.default_value
        else:
            # May raise (e.g. IndexError) before anything is assigned.
            chosen = conn_context.ws_requested_protocols[self.index]
        conn_context.ws_protocol = chosen

    def transfer_data(self, conn_context):
        """No-op; these handshake tests transfer no data."""
99
100
class HandshakeAbortedException(Exception):
    """Raised by AbortingDispatcher to simulate a rejected handshake."""
103
104
class AbortingDispatcher(object):
    """Test dispatcher whose extra handshake always fails.

    ``do_extra_handshake`` raises HandshakeAbortedException. Tests use it to
    check that Handshaker.do_handshake lets the error propagate rather than
    catching it.
    """
    def do_extra_handshake(self, conn_context):
        reason = 'An exception to reject the request'
        raise HandshakeAbortedException(reason)

    def transfer_data(self, conn_context):
        """No-op; the handshake is rejected before any data transfer."""
114
115
class AbortedByUserDispatcher(object):
    """Test dispatcher that rejects the request with AbortedByUserException.

    Tests use it to check that Handshaker.do_handshake does not catch the
    AbortedByUserException raised from ``do_extra_handshake``.
    """
    def do_extra_handshake(self, conn_context):
        raise AbortedByUserException(
            'An AbortedByUserException to reject the request')

    def transfer_data(self, conn_context):
        """No-op; the handshake is rejected before any data transfer."""
126
127
128_EXPECTED_RESPONSE = (
129    b'HTTP/1.1 101 Switching Protocols\r\n'
130    b'Upgrade: websocket\r\n'
131    b'Connection: Upgrade\r\n'
132    b'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n')
133
134
class HandshakerTest(unittest.TestCase):
    """Tests for the hybi opening handshake processor (Handshaker).

    Targets draft-ietf-hybi-thewebsocketprotocol-06 and later. The baseline
    request built by _create_good_request_def() uses
    Sec-WebSocket-Version 13.
    """
    def test_do_handshake(self):
        request = _create_request(_create_good_request_def())
        dispatcher = mock.MockDispatcher()
        handshaker = Handshaker(request, dispatcher)
        handshaker.do_handshake()

        self.assertTrue(dispatcher.do_extra_handshake_called)

        self.assertEqual(_EXPECTED_RESPONSE, request.connection.written_data())
        self.assertEqual('/demo', request.ws_resource)
        self.assertEqual('http://example.com', request.ws_origin)
        self.assertIsNone(request.ws_protocol)
        self.assertIsNone(request.ws_extensions)
        self.assertEqual(common.VERSION_HYBI_LATEST, request.ws_version)

    def test_do_handshake_with_extra_headers(self):
        request_def = _create_good_request_def()
        # Add headers not related to WebSocket opening handshake.
        request_def.headers['FooKey'] = 'BarValue'
        request_def.headers['EmptyKey'] = ''

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(_EXPECTED_RESPONSE, request.connection.written_data())

    def test_do_handshake_with_capitalized_value(self):
        request_def = _create_good_request_def()
        # Overwrite the existing 'Upgrade' key. A differently-cased key such
        # as 'upgrade' would add a second entry to the plain headers dict,
        # leaving the original 'websocket' value in place alongside it.
        request_def.headers['Upgrade'] = 'WEBSOCKET'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(_EXPECTED_RESPONSE, request.connection.written_data())

        request_def = _create_good_request_def()
        request_def.headers['Connection'] = 'UPGRADE'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(_EXPECTED_RESPONSE, request.connection.written_data())

    def test_do_handshake_with_multiple_connection_values(self):
        request_def = _create_good_request_def()
        # Empty list elements must be tolerated.
        request_def.headers['Connection'] = 'Upgrade, keep-alive, , '

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(_EXPECTED_RESPONSE, request.connection.written_data())

    def test_aborting_handshake(self):
        handshaker = Handshaker(_create_request(_create_good_request_def()),
                                AbortingDispatcher())
        # do_extra_handshake raises an exception. Check that it's not caught by
        # do_handshake.
        self.assertRaises(HandshakeAbortedException, handshaker.do_handshake)

    def test_do_handshake_with_protocol(self):
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Protocol'] = 'chat, superchat'

        request = _create_request(request_def)
        handshaker = Handshaker(request, SubprotocolChoosingDispatcher(0))
        handshaker.do_handshake()

        EXPECTED_RESPONSE = (
            b'HTTP/1.1 101 Switching Protocols\r\n'
            b'Upgrade: websocket\r\n'
            b'Connection: Upgrade\r\n'
            b'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n'
            b'Sec-WebSocket-Protocol: chat\r\n\r\n')

        self.assertEqual(EXPECTED_RESPONSE, request.connection.written_data())
        self.assertEqual('chat', request.ws_protocol)

    def test_do_handshake_protocol_not_in_request_but_in_response(self):
        request_def = _create_good_request_def()
        request = _create_request(request_def)
        handshaker = Handshaker(request,
                                SubprotocolChoosingDispatcher(-1, 'foobar'))
        # No request has been made but ws_protocol is set. HandshakeException
        # must be raised.
        self.assertRaises(HandshakeException, handshaker.do_handshake)

    def test_do_handshake_with_protocol_no_protocol_selection(self):
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Protocol'] = 'chat, superchat'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        # ws_protocol is not set. HandshakeException must be raised.
        self.assertRaises(HandshakeException, handshaker.do_handshake)

    def test_do_handshake_with_extensions(self):
        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Extensions'] = (
            'permessage-deflate; server_no_context_takeover')

        EXPECTED_RESPONSE = (
            b'HTTP/1.1 101 Switching Protocols\r\n'
            b'Upgrade: websocket\r\n'
            b'Connection: Upgrade\r\n'
            b'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n'
            b'Sec-WebSocket-Extensions: '
            b'permessage-deflate; server_no_context_takeover\r\n'
            b'\r\n')

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(EXPECTED_RESPONSE, request.connection.written_data())
        self.assertEqual(1, len(request.ws_extensions))
        extension = request.ws_extensions[0]
        self.assertEqual(common.PERMESSAGE_DEFLATE_EXTENSION, extension.name())
        self.assertEqual(['server_no_context_takeover'],
                         extension.get_parameter_names())
        self.assertIsNone(
            extension.get_parameter_value('server_no_context_takeover'))
        self.assertEqual(1, len(request.ws_extension_processors))
        self.assertEqual('deflate', request.ws_extension_processors[0].name())

    def test_do_handshake_with_quoted_extensions(self):
        request_def = _create_good_request_def()
        # Exercises empty list elements, whitespace around '=', quoted-string
        # values with folded whitespace and a backslash escape, and a token
        # value.
        request_def.headers['Sec-WebSocket-Extensions'] = (
            'permessage-deflate, , '
            'unknown; e   =    "mc^2"; ma="\r\n      \\\rf  "; pv=nrt')

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual(2, len(request.ws_requested_extensions))
        first_extension = request.ws_requested_extensions[0]
        self.assertEqual('permessage-deflate', first_extension.name())
        second_extension = request.ws_requested_extensions[1]
        self.assertEqual('unknown', second_extension.name())
        self.assertEqual(['e', 'ma', 'pv'],
                         second_extension.get_parameter_names())
        self.assertEqual('mc^2', second_extension.get_parameter_value('e'))
        self.assertEqual(' \rf ', second_extension.get_parameter_value('ma'))
        self.assertEqual('nrt', second_extension.get_parameter_value('pv'))

    def test_do_handshake_with_optional_headers(self):
        request_def = _create_good_request_def()
        request_def.headers['EmptyValue'] = ''
        request_def.headers['AKey'] = 'AValue'

        request = _create_request(request_def)
        handshaker = _create_handshaker(request)
        handshaker.do_handshake()
        self.assertEqual('AValue', request.headers_in['AKey'])
        self.assertEqual('', request.headers_in['EmptyValue'])

    def test_abort_extra_handshake(self):
        handshaker = Handshaker(_create_request(_create_good_request_def()),
                                AbortedByUserDispatcher())
        # do_extra_handshake raises an AbortedByUserException. Check that it's
        # not caught by do_handshake.
        self.assertRaises(AbortedByUserException, handshaker.do_handshake)

    def test_bad_requests(self):
        # Each case is (name, request definition, expected HandshakeException
        # status, True if a HandshakeException is expected / False if a
        # VersionException is expected).
        bad_cases = [
            ('HTTP request',
             RequestDefinition(
                 'GET', '/demo', {
                     'Host':
                     'www.google.com',
                     'User-Agent':
                     'Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10.5;'
                     ' en-US; rv:1.9.1.3) Gecko/20090824 Firefox/3.5.3'
                     ' GTB6 GTBA',
                     'Accept':
                     'text/html,application/xhtml+xml,application/xml;q=0.9,'
                     '*/*;q=0.8',
                     'Accept-Language':
                     'en-us,en;q=0.5',
                     'Accept-Encoding':
                     'gzip,deflate',
                     'Accept-Charset':
                     'ISO-8859-1,utf-8;q=0.7,*;q=0.7',
                     'Keep-Alive':
                     '300',
                     'Connection':
                     'keep-alive'
                 }), None, True)
        ]

        request_def = _create_good_request_def()
        request_def.method = 'POST'
        bad_cases.append(('Wrong method', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Host']
        bad_cases.append(('Missing Host', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Upgrade']
        bad_cases.append(('Missing Upgrade', request_def, None, True))

        request_def = _create_good_request_def()
        request_def.headers['Upgrade'] = 'nonwebsocket'
        bad_cases.append(('Wrong Upgrade', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Connection']
        bad_cases.append(('Missing Connection', request_def, None, True))

        request_def = _create_good_request_def()
        request_def.headers['Connection'] = 'Downgrade'
        bad_cases.append(('Wrong Connection', request_def, None, True))

        request_def = _create_good_request_def()
        del request_def.headers['Sec-WebSocket-Key']
        bad_cases.append(('Missing Sec-WebSocket-Key', request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Key'] = (
            'dGhlIHNhbXBsZSBub25jZQ==garbage')
        bad_cases.append(('Wrong Sec-WebSocket-Key (with garbage on the tail)',
                          request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Key'] = 'YQ=='  # BASE64 of 'a'
        bad_cases.append(
            ('Wrong Sec-WebSocket-Key (decoded value is not 16 octets long)',
             request_def, 400, True))

        request_def = _create_good_request_def()
        # The last character right before == must be any of A, Q, w and g.
        request_def.headers['Sec-WebSocket-Key'] = 'AQIDBAUGBwgJCgsMDQ4PEC=='
        bad_cases.append(
            ('Wrong Sec-WebSocket-Key (padding bits are not zero)',
             request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Key'] = (
            'dGhlIHNhbXBsZSBub25jZQ==,dGhlIHNhbXBsZSBub25jZQ==')
        bad_cases.append(('Wrong Sec-WebSocket-Key (multiple values)',
                          request_def, 400, True))

        request_def = _create_good_request_def()
        del request_def.headers['Sec-WebSocket-Version']
        bad_cases.append(
            ('Missing Sec-WebSocket-Version', request_def, None, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Version'] = '3'
        bad_cases.append(
            ('Wrong Sec-WebSocket-Version', request_def, None, False))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Version'] = '13, 13'
        bad_cases.append(('Wrong Sec-WebSocket-Version (multiple values)',
                          request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Protocol'] = 'illegal\x09protocol'
        bad_cases.append(
            ('Illegal Sec-WebSocket-Protocol', request_def, 400, True))

        request_def = _create_good_request_def()
        request_def.headers['Sec-WebSocket-Protocol'] = ''
        bad_cases.append(
            ('Empty Sec-WebSocket-Protocol', request_def, 400, True))

        for (case_name, request_def, expected_status,
             expect_handshake_exception) in bad_cases:
            request = _create_request(request_def)
            handshaker = _create_handshaker(request)
            # Keep only do_handshake() inside the try so that nothing else
            # can be mistaken for the expected rejection.
            try:
                handshaker.do_handshake()
            except HandshakeException as e:
                self.assertTrue(
                    expect_handshake_exception,
                    'Unexpected HandshakeException for \'%s\' case' %
                    case_name)
                self.assertEqual(
                    expected_status, e.status,
                    'Wrong status for \'%s\' case' % case_name)
            except VersionException:
                self.assertFalse(
                    expect_handshake_exception,
                    'Unexpected VersionException for \'%s\' case' % case_name)
            else:
                self.fail('No exception thrown for \'%s\' case' % case_name)
417
418
if __name__ == '__main__':
    # Run every TestCase in this module when the file is executed directly.
    unittest.main()
421
422# vi:sts=4 sw=4 et
423