1# vim: tabstop=4 shiftwidth=4 softtabstop=4
2
3# Copyright(c)2013 NTT corp. All Rights Reserved.
4#
5#    Licensed under the Apache License, Version 2.0 (the "License"); you may
6#    not use this file except in compliance with the License. You may obtain
7#    a copy of the License at
8#
9#         http://www.apache.org/licenses/LICENSE-2.0
10#
11#    Unless required by applicable law or agreed to in writing, software
12#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14#    License for the specific language governing permissions and limitations
15#    under the License.
16
17""" Unit tests for websockifyserver """
18import errno
19import os
20import logging
21import select
22import shutil
23import socket
24import ssl
25from mox3 import stubout
26import sys
27import tempfile
28import unittest
29import socket
30import signal
31from websockify import websockifyserver
32
33try:
34    from BaseHTTPServer import BaseHTTPRequestHandler
35except ImportError:
36    from http.server import BaseHTTPRequestHandler
37
38try:
39    from StringIO import StringIO
40    BytesIO = StringIO
41except ImportError:
42    from io import StringIO
43    from io import BytesIO
44
45
46
47
def raise_oserror(*args, **kwargs):
    """Stand-in for any callable that should fail; arguments are ignored.

    Raises:
        OSError: unconditionally, with the message ``'fake error'``.
    """
    failure = OSError('fake error')
    raise failure
50
51
class FakeSocket(object):
    """Minimal in-memory stand-in for a socket, backed by a byte buffer.

    Only the parts of the socket API used by these tests are provided:
    ``recv`` (with optional ``socket.MSG_PEEK`` support) and ``makefile``.
    """

    def __init__(self, data=''):
        # Keep the buffer as bytes; text input is encoded as latin-1 so that
        # every code point below 256 maps to exactly one byte.
        if isinstance(data, bytes):
            self._data = data
        else:
            self._data = data.encode('latin_1')

    def recv(self, amt, flags=None):
        """Return up to ``amt`` bytes from the front of the buffer.

        The returned bytes are consumed unless ``flags`` includes
        ``socket.MSG_PEEK``. ``flags`` may be omitted or ``None``, which
        behaves like a plain ``recv`` with no flags.
        """
        res = self._data[0:amt]
        # The default is None; treat it as "no flags" rather than letting
        # ``None & socket.MSG_PEEK`` raise TypeError.
        peek = flags is not None and bool(flags & socket.MSG_PEEK)
        if not peek:
            self._data = self._data[amt:]

        return res

    def makefile(self, mode='r', buffsize=None):
        """Return a file-like view of the current buffer contents.

        Binary modes (``'b'`` in ``mode``) get a BytesIO over the raw bytes.
        Text modes get a StringIO over the latin-1-decoded text. The buffer
        itself is not consumed.
        """
        if 'b' in mode:
            return BytesIO(self._data)
        else:
            return StringIO(self._data.decode('latin_1'))
71
72
class WebSockifyRequestHandlerTestCase(unittest.TestCase):
    """Tests for ``websockifyserver.WebSockifyRequestHandler``."""

    def setUp(self):
        super(WebSockifyRequestHandlerTestCase, self).setUp()
        self.stubs = stubout.StubOutForTesting()
        self.tmpdir = tempfile.mkdtemp('-websockify-tests')
        # Make os.chdir a no-op so the code under test cannot change the
        # working directory of the test process.
        self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
        # Silence responses by default; individual tests may re-stub this to
        # record the status code (see _record_send_response).
        self.stubs.Set(BaseHTTPRequestHandler, 'send_response',
                       lambda *args, **kwargs: None)

        def fake_send_error(self, code, message=None, explain=None):
            # Remember the error code on the handler instance for assertions.
            self.last_code = code

        self.stubs.Set(BaseHTTPRequestHandler, 'send_error',
                       fake_send_error)

    def tearDown(self):
        """Called automatically after each test."""
        self.stubs.UnsetAll()
        # rmtree (rather than rmdir) so cleanup still succeeds if anything
        # was written into the temporary directory during the test.
        shutil.rmtree(self.tmpdir)
        super(WebSockifyRequestHandlerTestCase, self).tearDown()

    def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler,
                    **kwargs):
        """Build a non-daemon, non-SSL-only server rooted at ``self.tmpdir``.

        ``web`` defaults to ``self.tmpdir`` but may be overridden (e.g. with
        ``None``); any other keyword arguments are forwarded unchanged.
        """
        web = kwargs.pop('web', self.tmpdir)
        return websockifyserver.WebSockifyServer(
            handler_class, listen_host='localhost',
            listen_port=80, key=self.tmpdir, web=web,
            record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1,
            **kwargs)

    def _record_send_response(self):
        """Stub send_response so the handler stores the code in last_code."""
        def fake_send_response(handler, code, message=None):
            handler.last_code = code

        self.stubs.Set(BaseHTTPRequestHandler, 'send_response',
                       fake_send_response)

    def test_normal_get_with_only_upgrade_returns_error(self):
        server = self._get_server(web=None)
        handler = websockifyserver.WebSockifyRequestHandler(
            FakeSocket('GET /tmp.txt HTTP/1.1'), '127.0.0.1', server)

        self._record_send_response()

        handler.do_GET()
        self.assertEqual(handler.last_code, 405)

    def test_list_dir_with_file_only_returns_error(self):
        server = self._get_server(file_only=True)
        handler = websockifyserver.WebSockifyRequestHandler(
            FakeSocket('GET / HTTP/1.1'), '127.0.0.1', server)

        self._record_send_response()

        handler.path = '/'
        handler.do_GET()
        self.assertEqual(handler.last_code, 404)
132
133
class WebSockifyServerTestCase(unittest.TestCase):
    """Tests for ``websockifyserver.WebSockifyServer``."""

    def setUp(self):
        super(WebSockifyServerTestCase, self).setUp()
        self.stubs = stubout.StubOutForTesting()
        self.tmpdir = tempfile.mkdtemp('-websockify-tests')
        # Make os.chdir a no-op so the code under test cannot change the
        # working directory of the test process.
        self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)

    def tearDown(self):
        """Called automatically after each test."""
        self.stubs.UnsetAll()
        # rmtree (rather than rmdir) so cleanup still succeeds if anything
        # was written into the temporary directory during the test.
        shutil.rmtree(self.tmpdir)
        super(WebSockifyServerTestCase, self).tearDown()

    def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler,
                    **kwargs):
        """Build a server on localhost:80 using ``self.tmpdir`` for paths.

        Extra keyword arguments are forwarded to the server constructor.
        """
        return websockifyserver.WebSockifyServer(
            handler_class, listen_host='localhost',
            listen_port=80, key=self.tmpdir, web=self.tmpdir,
            record=self.tmpdir, **kwargs)

    def _make_socket(self, server, *args, **kwargs):
        """Call ``server.socket`` and make sure the result gets closed.

        socket.close() is idempotent, so this is safe even when the code
        under test closes the socket itself.
        """
        sock = server.socket(*args, **kwargs)
        self.addCleanup(sock.close)
        return sock

    def test_daemonize_raises_error_while_closing_fds(self):
        server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
        self.stubs.Set(os, 'fork', lambda *args: 0)
        self.stubs.Set(signal, 'signal', lambda *args: None)
        self.stubs.Set(os, 'setsid', lambda *args: None)
        self.stubs.Set(os, 'close', raise_oserror)
        self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')

    def test_daemonize_ignores_ebadf_error_while_closing_fds(self):
        def raise_oserror_ebadf(fd):
            raise OSError(errno.EBADF, 'fake error')

        server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
        self.stubs.Set(os, 'fork', lambda *args: 0)
        self.stubs.Set(os, 'setsid', lambda *args: None)
        self.stubs.Set(signal, 'signal', lambda *args: None)
        self.stubs.Set(os, 'close', raise_oserror_ebadf)
        # EBADF from close() must be tolerated, so the OSError expected here
        # comes from the stubbed os.open instead.
        self.stubs.Set(os, 'open', raise_oserror)
        self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')

    def test_handshake_fails_on_not_ready(self):
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([], [], [])

        self.stubs.Set(select, 'select', fake_select)
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            FakeSocket(), '127.0.0.1')

    def test_empty_handshake_fails(self):
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)

        sock = FakeSocket('')

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        self.stubs.Set(select, 'select', fake_select)
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_handshake_policy_request(self):
        # TODO(directxman12): implement
        self.skipTest('not implemented yet')

    def test_handshake_ssl_only_without_ssl_raises_error(self):
        server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)

        sock = FakeSocket('some initial data')

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        self.stubs.Set(select, 'select', fake_select)
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_do_handshake_no_ssl(self):
        class FakeHandler(object):
            CALLED = False
            def __init__(self, *args, **kwargs):
                type(self).CALLED = True

        FakeHandler.CALLED = False

        server = self._get_server(
            handler_class=FakeHandler, daemon=True,
            ssl_only=0, idle_timeout=1)

        sock = FakeSocket('some initial data')

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        self.stubs.Set(select, 'select', fake_select)
        self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock)
        self.assertTrue(FakeHandler.CALLED)

    def test_do_handshake_ssl(self):
        # TODO(directxman12): implement this
        self.skipTest('not implemented yet')

    def test_do_handshake_ssl_without_ssl_raises_error(self):
        # TODO(directxman12): implement this
        self.skipTest('not implemented yet')

    def test_do_handshake_ssl_without_cert_raises_error(self):
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1,
                                  cert='afdsfasdafdsafdsafdsafdas')

        # A leading 0x16 byte makes the data look like an SSL/TLS handshake.
        sock = FakeSocket("\x16some ssl data")

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        self.stubs.Set(select, 'select', fake_select)
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_do_handshake_ssl_error_eof_raises_close_error(self):
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)

        sock = FakeSocket("\x16some ssl data")

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        def fake_wrap_socket(*args, **kwargs):
            raise ssl.SSLError(ssl.SSL_ERROR_EOF)

        class fake_create_default_context():
            def __init__(self, purpose):
                self.verify_mode = None
                self.options = 0
            def load_cert_chain(self, certfile, keyfile, password):
                pass
            def set_default_verify_paths(self):
                pass
            def load_verify_locations(self, cafile):
                pass
            def wrap_socket(self, *args, **kwargs):
                raise ssl.SSLError(ssl.SSL_ERROR_EOF)

        self.stubs.Set(select, 'select', fake_select)
        if (hasattr(ssl, 'create_default_context')):
            # for recent versions of python
            self.stubs.Set(ssl, 'create_default_context', fake_create_default_context)
        else:
            # fallback for old versions of python
            self.stubs.Set(ssl, 'wrap_socket', fake_wrap_socket)
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_do_handshake_ssl_sets_ciphers(self):
        test_ciphers = 'TEST-CIPHERS-1:TEST-CIPHER-2'

        class FakeHandler(object):
            def __init__(self, *args, **kwargs):
                pass

        server = self._get_server(handler_class=FakeHandler, daemon=True,
                                  idle_timeout=1, ssl_ciphers=test_ciphers)
        sock = FakeSocket("\x16some ssl data")

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        class fake_create_default_context():
            # Records the last cipher string passed to set_ciphers.
            CIPHERS = ''
            def __init__(self, purpose):
                self.verify_mode = None
                self.options = 0
            def load_cert_chain(self, certfile, keyfile, password):
                pass
            def set_default_verify_paths(self):
                pass
            def load_verify_locations(self, cafile):
                pass
            def wrap_socket(self, *args, **kwargs):
                pass
            def set_ciphers(self, ciphers_to_set):
                fake_create_default_context.CIPHERS = ciphers_to_set

        self.stubs.Set(select, 'select', fake_select)
        if (hasattr(ssl, 'create_default_context')):
            # for recent versions of python
            self.stubs.Set(ssl, 'create_default_context', fake_create_default_context)
            server.do_handshake(sock, '127.0.0.1')
            self.assertEqual(fake_create_default_context.CIPHERS, test_ciphers)
        else:
            # fallback for old versions of python: not supported, so there
            # is nothing to test; report a skip instead of a silent pass.
            self.skipTest('ssl.create_default_context is not available')

    def test_do_handshake_ssl_sets_opions(self):
        test_options = 0xCAFEBEEF

        class FakeHandler(object):
            def __init__(self, *args, **kwargs):
                pass

        server = self._get_server(handler_class=FakeHandler, daemon=True,
                                  idle_timeout=1, ssl_options=test_options)
        sock = FakeSocket("\x16some ssl data")

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        class fake_create_default_context(object):
            # Records the last value assigned to the ``options`` property.
            OPTIONS = 0
            def __init__(self, purpose):
                self.verify_mode = None
                self._options = 0
            def load_cert_chain(self, certfile, keyfile, password):
                pass
            def set_default_verify_paths(self):
                pass
            def load_verify_locations(self, cafile):
                pass
            def wrap_socket(self, *args, **kwargs):
                pass
            def get_options(self):
                return self._options
            def set_options(self, val):
                fake_create_default_context.OPTIONS = val
            options = property(get_options, set_options)

        self.stubs.Set(select, 'select', fake_select)
        if (hasattr(ssl, 'create_default_context')):
            # for recent versions of python
            self.stubs.Set(ssl, 'create_default_context', fake_create_default_context)
            server.do_handshake(sock, '127.0.0.1')
            self.assertEqual(fake_create_default_context.OPTIONS, test_options)
        else:
            # fallback for old versions of python: not supported, so there
            # is nothing to test; report a skip instead of a silent pass.
            self.skipTest('ssl.create_default_context is not available')

    def test_fallback_sigchld_handler(self):
        # TODO(directxman12): implement this
        self.skipTest('not implemented yet')

    def test_start_server_error(self):
        server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1)
        sock = self._make_socket(server, 'localhost')

        def fake_select(rlist, wlist, xlist, timeout=None):
            raise Exception("fake error")

        self.stubs.Set(websockifyserver.WebSockifyServer, 'socket',
                       lambda *args, **kwargs: sock)
        self.stubs.Set(websockifyserver.WebSockifyServer, 'daemonize',
                       lambda *args, **kwargs: None)
        self.stubs.Set(select, 'select', fake_select)
        server.start_server()

    def test_start_server_keyboardinterrupt(self):
        server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
        sock = self._make_socket(server, 'localhost')

        def fake_select(rlist, wlist, xlist, timeout=None):
            raise KeyboardInterrupt

        self.stubs.Set(websockifyserver.WebSockifyServer, 'socket',
                       lambda *args, **kwargs: sock)
        self.stubs.Set(websockifyserver.WebSockifyServer, 'daemonize',
                       lambda *args, **kwargs: None)
        self.stubs.Set(select, 'select', fake_select)
        server.start_server()

    def test_start_server_systemexit(self):
        server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
        sock = self._make_socket(server, 'localhost')

        def fake_select(rlist, wlist, xlist, timeout=None):
            sys.exit()

        self.stubs.Set(websockifyserver.WebSockifyServer, 'socket',
                       lambda *args, **kwargs: sock)
        self.stubs.Set(websockifyserver.WebSockifyServer, 'daemonize',
                       lambda *args, **kwargs: None)
        self.stubs.Set(select, 'select', fake_select)
        server.start_server()

    def test_socket_set_keepalive_options(self):
        keepcnt = 12
        keepidle = 34
        keepintvl = 56

        server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
        sock = self._make_socket(server, 'localhost',
                                 tcp_keepcnt=keepcnt,
                                 tcp_keepidle=keepidle,
                                 tcp_keepintvl=keepintvl)

        if hasattr(socket, 'TCP_KEEPCNT'):
            self.assertEqual(sock.getsockopt(socket.SOL_TCP,
                                             socket.TCP_KEEPCNT), keepcnt)
        self.assertEqual(sock.getsockopt(socket.SOL_TCP,
                                         socket.TCP_KEEPIDLE), keepidle)
        self.assertEqual(sock.getsockopt(socket.SOL_TCP,
                                         socket.TCP_KEEPINTVL), keepintvl)

        # With keepalive disabled the tuning values must not be applied.
        sock = self._make_socket(server, 'localhost',
                                 tcp_keepalive=False,
                                 tcp_keepcnt=keepcnt,
                                 tcp_keepidle=keepidle,
                                 tcp_keepintvl=keepintvl)

        if hasattr(socket, 'TCP_KEEPCNT'):
            self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
                                                socket.TCP_KEEPCNT), keepcnt)
        self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
                                            socket.TCP_KEEPIDLE), keepidle)
        self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
                                            socket.TCP_KEEPINTVL), keepintvl)
457