# vim: tabstop=4 shiftwidth=4 softtabstop=4

# Copyright(c)2013 NTT corp. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

"""Unit tests for websockifyserver."""

import errno
import logging
import os
import select
import shutil
import signal
import socket
import ssl
import sys
import tempfile
import unittest

from mox3 import stubout

from websockify import websockifyserver

try:
    from BaseHTTPServer import BaseHTTPRequestHandler
except ImportError:
    from http.server import BaseHTTPRequestHandler

try:
    from StringIO import StringIO
    BytesIO = StringIO
except ImportError:
    from io import StringIO
    from io import BytesIO


def raise_oserror(*args, **kwargs):
    """Stub replacement that accepts any arguments and raises OSError."""
    raise OSError('fake error')


class FakeSocket(object):
    """Minimal socket stand-in that serves a fixed buffer of bytes.

    Only ``recv`` and ``makefile`` are implemented. Text passed to the
    constructor is encoded as latin-1 so each character maps to one byte.
    """

    def __init__(self, data=''):
        if isinstance(data, bytes):
            self._data = data
        else:
            self._data = data.encode('latin_1')

    def recv(self, amt, flags=0):
        """Return up to ``amt`` bytes, consuming them unless MSG_PEEK is set.

        ``flags`` defaults to 0 like ``socket.recv``; an explicit ``None`` is
        treated as 0 as well, since ``None & MSG_PEEK`` would raise TypeError.
        """
        res = self._data[:amt]
        if not ((flags or 0) & socket.MSG_PEEK):
            self._data = self._data[amt:]
        return res

    def makefile(self, mode='r', buffsize=None):
        """Return an in-memory file over the data not yet consumed by recv.

        A binary mode yields a BytesIO; otherwise the bytes are decoded as
        latin-1 into a StringIO.
        """
        if 'b' in mode:
            return BytesIO(self._data)
        return StringIO(self._data.decode('latin_1'))


class WebSockifyRequestHandlerTestCase(unittest.TestCase):
    """Tests for websockifyserver.WebSockifyRequestHandler."""

    def setUp(self):
        super(WebSockifyRequestHandlerTestCase, self).setUp()
        self.stubs = stubout.StubOutForTesting()
        self.tmpdir = tempfile.mkdtemp('-websockify-tests')
        # Make os.chdir a no-op so code under test cannot change the test
        # process's working directory.
        self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
        self.stubs.Set(BaseHTTPRequestHandler, 'send_response',
                       lambda *args, **kwargs: None)

        def fake_send_error(self, code, message=None, explain=None):
            # Record the status code on the handler so tests can assert it.
            self.last_code = code

        self.stubs.Set(BaseHTTPRequestHandler, 'send_error',
                       fake_send_error)

    def tearDown(self):
        """Called automatically after each test."""
        self.stubs.UnsetAll()
        # tmpdir is also handed to the server as its record directory, so
        # remove it recursively rather than assuming it is still empty.
        shutil.rmtree(self.tmpdir)
        super(WebSockifyRequestHandlerTestCase, self).tearDown()

    def _get_server(self,
                    handler_class=websockifyserver.WebSockifyRequestHandler,
                    **kwargs):
        """Build a non-daemon WebSockifyServer rooted at the temp dir.

        ``web`` defaults to the temp dir; other keyword arguments are passed
        straight through to the WebSockifyServer constructor.
        """
        web = kwargs.pop('web', self.tmpdir)
        return websockifyserver.WebSockifyServer(
            handler_class, listen_host='localhost',
            listen_port=80, key=self.tmpdir, web=web,
            record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1,
            **kwargs)

    def test_normal_get_with_only_upgrade_returns_error(self):
        server = self._get_server(web=None)
        handler = websockifyserver.WebSockifyRequestHandler(
            FakeSocket('GET /tmp.txt HTTP/1.1'), '127.0.0.1', server)

        def fake_send_response(self, code, message=None):
            self.last_code = code

        self.stubs.Set(BaseHTTPRequestHandler, 'send_response',
                       fake_send_response)

        handler.do_GET()
        self.assertEqual(handler.last_code, 405)

    def test_list_dir_with_file_only_returns_error(self):
        server = self._get_server(file_only=True)
        handler = websockifyserver.WebSockifyRequestHandler(
            FakeSocket('GET / HTTP/1.1'), '127.0.0.1', server)

        def fake_send_response(self, code, message=None):
            self.last_code = code

        self.stubs.Set(BaseHTTPRequestHandler, 'send_response',
                       fake_send_response)

        handler.path = '/'
        handler.do_GET()
        self.assertEqual(handler.last_code, 404)


class WebSockifyServerTestCase(unittest.TestCase):
    """Tests for websockifyserver.WebSockifyServer."""

    def setUp(self):
        super(WebSockifyServerTestCase, self).setUp()
        self.stubs = stubout.StubOutForTesting()
        self.tmpdir = tempfile.mkdtemp('-websockify-tests')
        # Make os.chdir a no-op so code under test cannot change the test
        # process's working directory.
        self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)

    def tearDown(self):
        """Called automatically after each test."""
        self.stubs.UnsetAll()
        # tmpdir is also handed to the server as its record directory, so
        # remove it recursively rather than assuming it is still empty.
        shutil.rmtree(self.tmpdir)
        super(WebSockifyServerTestCase, self).tearDown()

    def _get_server(self,
                    handler_class=websockifyserver.WebSockifyRequestHandler,
                    **kwargs):
        """Build a WebSockifyServer rooted at the temp dir.

        Keyword arguments (daemon, ssl_only, idle_timeout, ...) are passed
        straight through to the WebSockifyServer constructor.
        """
        return websockifyserver.WebSockifyServer(
            handler_class, listen_host='localhost',
            listen_port=80, key=self.tmpdir, web=self.tmpdir,
            record=self.tmpdir, **kwargs)

    def _start_server_with_select(self, fake_select, ssl_only):
        """Run start_server() with select.select replaced by ``fake_select``.

        WebSockifyServer.socket is stubbed to hand back a socket created here
        and daemonize() is stubbed to a no-op. The socket is closed during
        test cleanup so it does not leak (closing twice is harmless).
        """
        server = self._get_server(daemon=False, ssl_only=ssl_only,
                                  idle_timeout=1)
        sock = server.socket('localhost')
        self.addCleanup(sock.close)

        self.stubs.Set(websockifyserver.WebSockifyServer, 'socket',
                       lambda *args, **kwargs: sock)
        self.stubs.Set(websockifyserver.WebSockifyServer, 'daemonize',
                       lambda *args, **kwargs: None)
        self.stubs.Set(select, 'select', fake_select)
        server.start_server()

    def test_daemonize_raises_error_while_closing_fds(self):
        """A generic OSError from os.close propagates out of daemonize()."""
        server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
        self.stubs.Set(os, 'fork', lambda *args: 0)
        self.stubs.Set(signal, 'signal', lambda *args: None)
        self.stubs.Set(os, 'setsid', lambda *args: None)
        self.stubs.Set(os, 'close', raise_oserror)
        self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')

    def test_daemonize_ignores_ebadf_error_while_closing_fds(self):
        """EBADF from os.close is tolerated by daemonize().

        os.open is stubbed to raise as well; presumably that later call is
        the source of the OSError asserted here.
        """
        def raise_oserror_ebadf(fd):
            raise OSError(errno.EBADF, 'fake error')

        server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
        self.stubs.Set(os, 'fork', lambda *args: 0)
        self.stubs.Set(os, 'setsid', lambda *args: None)
        self.stubs.Set(signal, 'signal', lambda *args: None)
        self.stubs.Set(os, 'close', raise_oserror_ebadf)
        self.stubs.Set(os, 'open', raise_oserror)
        self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')

    def test_handshake_fails_on_not_ready(self):
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([], [], [])

        self.stubs.Set(select, 'select', fake_select)
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            FakeSocket(), '127.0.0.1')

    def test_empty_handshake_fails(self):
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)

        sock = FakeSocket('')

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        self.stubs.Set(select, 'select', fake_select)
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_handshake_policy_request(self):
        # TODO(directxman12): implement
        pass

    def test_handshake_ssl_only_without_ssl_raises_error(self):
        server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)

        sock = FakeSocket('some initial data')

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        self.stubs.Set(select, 'select', fake_select)
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_do_handshake_no_ssl(self):
        """A plain (non-TLS) handshake returns the socket and runs the handler."""
        class FakeHandler(object):
            CALLED = False

            def __init__(self, *args, **kwargs):
                type(self).CALLED = True

        server = self._get_server(
            handler_class=FakeHandler, daemon=True,
            ssl_only=0, idle_timeout=1)

        sock = FakeSocket('some initial data')

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        self.stubs.Set(select, 'select', fake_select)
        self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock)
        self.assertTrue(FakeHandler.CALLED)

    def test_do_handshake_ssl(self):
        # TODO(directxman12): implement this
        pass

    def test_do_handshake_ssl_without_ssl_raises_error(self):
        # TODO(directxman12): implement this
        pass

    def test_do_handshake_ssl_without_cert_raises_error(self):
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1,
                                  cert='afdsfasdafdsafdsafdsafdas')

        # A leading 0x16 byte marks the data as a TLS handshake record.
        sock = FakeSocket("\x16some ssl data")

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        self.stubs.Set(select, 'select', fake_select)
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_do_handshake_ssl_error_eof_raises_close_error(self):
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)

        sock = FakeSocket("\x16some ssl data")

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        def fake_wrap_socket(*args, **kwargs):
            raise ssl.SSLError(ssl.SSL_ERROR_EOF)

        class fake_create_default_context():
            def __init__(self, purpose):
                self.verify_mode = None
                self.options = 0

            def load_cert_chain(self, certfile, keyfile, password):
                pass

            def set_default_verify_paths(self):
                pass

            def load_verify_locations(self, cafile):
                pass

            def wrap_socket(self, *args, **kwargs):
                raise ssl.SSLError(ssl.SSL_ERROR_EOF)

        self.stubs.Set(select, 'select', fake_select)
        if hasattr(ssl, 'create_default_context'):
            # For recent versions of Python.
            self.stubs.Set(ssl, 'create_default_context',
                           fake_create_default_context)
        else:
            # Fallback for old versions of Python.
            self.stubs.Set(ssl, 'wrap_socket', fake_wrap_socket)
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_do_handshake_ssl_sets_ciphers(self):
        test_ciphers = 'TEST-CIPHERS-1:TEST-CIPHER-2'

        class FakeHandler(object):
            def __init__(self, *args, **kwargs):
                pass

        server = self._get_server(handler_class=FakeHandler, daemon=True,
                                  idle_timeout=1, ssl_ciphers=test_ciphers)
        sock = FakeSocket("\x16some ssl data")

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        class fake_create_default_context():
            CIPHERS = ''

            def __init__(self, purpose):
                self.verify_mode = None
                self.options = 0

            def load_cert_chain(self, certfile, keyfile, password):
                pass

            def set_default_verify_paths(self):
                pass

            def load_verify_locations(self, cafile):
                pass

            def wrap_socket(self, *args, **kwargs):
                pass

            def set_ciphers(self, ciphers_to_set):
                # Stored on the class so the test can read it after the
                # handshake, without holding a reference to the instance.
                fake_create_default_context.CIPHERS = ciphers_to_set

        self.stubs.Set(select, 'select', fake_select)
        if hasattr(ssl, 'create_default_context'):
            # For recent versions of Python.
            self.stubs.Set(ssl, 'create_default_context',
                           fake_create_default_context)
            server.do_handshake(sock, '127.0.0.1')
            self.assertEqual(fake_create_default_context.CIPHERS,
                             test_ciphers)
        else:
            # Fallback for old versions of Python:
            # not supported, nothing to test.
            pass

    def test_do_handshake_ssl_sets_opions(self):
        test_options = 0xCAFEBEEF

        class FakeHandler(object):
            def __init__(self, *args, **kwargs):
                pass

        server = self._get_server(handler_class=FakeHandler, daemon=True,
                                  idle_timeout=1, ssl_options=test_options)
        sock = FakeSocket("\x16some ssl data")

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        class fake_create_default_context(object):
            OPTIONS = 0

            def __init__(self, purpose):
                self.verify_mode = None
                self._options = 0

            def load_cert_chain(self, certfile, keyfile, password):
                pass

            def set_default_verify_paths(self):
                pass

            def load_verify_locations(self, cafile):
                pass

            def wrap_socket(self, *args, **kwargs):
                pass

            def get_options(self):
                return self._options

            def set_options(self, val):
                # Capture the last value assigned to ``options`` on the class
                # so the test can inspect it after the handshake.
                fake_create_default_context.OPTIONS = val

            options = property(get_options, set_options)

        self.stubs.Set(select, 'select', fake_select)
        if hasattr(ssl, 'create_default_context'):
            # For recent versions of Python.
            self.stubs.Set(ssl, 'create_default_context',
                           fake_create_default_context)
            server.do_handshake(sock, '127.0.0.1')
            self.assertEqual(fake_create_default_context.OPTIONS,
                             test_options)
        else:
            # Fallback for old versions of Python:
            # not supported, nothing to test.
            pass

    def test_fallback_sigchld_handler(self):
        # TODO(directxman12): implement this
        pass

    def test_start_server_error(self):
        def fake_select(rlist, wlist, xlist, timeout=None):
            raise Exception("fake error")

        self._start_server_with_select(fake_select, ssl_only=1)

    def test_start_server_keyboardinterrupt(self):
        def fake_select(rlist, wlist, xlist, timeout=None):
            raise KeyboardInterrupt

        self._start_server_with_select(fake_select, ssl_only=0)

    def test_start_server_systemexit(self):
        def fake_select(rlist, wlist, xlist, timeout=None):
            sys.exit()

        self._start_server_with_select(fake_select, ssl_only=0)

    def test_socket_set_keepalive_options(self):
        """Keep-alive tuning is applied only when tcp_keepalive is enabled."""
        keepcnt = 12
        keepidle = 34
        keepintvl = 56

        server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
        sock = server.socket('localhost',
                             tcp_keepcnt=keepcnt,
                             tcp_keepidle=keepidle,
                             tcp_keepintvl=keepintvl)
        self.addCleanup(sock.close)

        # The TCP_KEEP* options are platform-specific; only check them where
        # the socket module exposes them.
        if hasattr(socket, 'TCP_KEEPCNT'):
            self.assertEqual(sock.getsockopt(socket.SOL_TCP,
                                             socket.TCP_KEEPCNT), keepcnt)
            self.assertEqual(sock.getsockopt(socket.SOL_TCP,
                                             socket.TCP_KEEPIDLE), keepidle)
            self.assertEqual(sock.getsockopt(socket.SOL_TCP,
                                             socket.TCP_KEEPINTVL), keepintvl)

        sock = server.socket('localhost',
                             tcp_keepalive=False,
                             tcp_keepcnt=keepcnt,
                             tcp_keepidle=keepidle,
                             tcp_keepintvl=keepintvl)
        self.addCleanup(sock.close)

        if hasattr(socket, 'TCP_KEEPCNT'):
            self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
                                                socket.TCP_KEEPCNT), keepcnt)
            self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
                                                socket.TCP_KEEPIDLE),
                                keepidle)
            self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
                                                socket.TCP_KEEPINTVL),
                                keepintvl)