1# -*- coding: utf-8 -*- 2# 3""" 4 5""" 6 7""" 8websocket - WebSocket client library for Python 9 10Copyright (C) 2010 Hiroki Ohtani(liris) 11 12 This library is free software; you can redistribute it and/or 13 modify it under the terms of the GNU Lesser General Public 14 License as published by the Free Software Foundation; either 15 version 2.1 of the License, or (at your option) any later version. 16 17 This library is distributed in the hope that it will be useful, 18 but WITHOUT ANY WARRANTY; without even the implied warranty of 19 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 20 Lesser General Public License for more details. 21 22 You should have received a copy of the GNU Lesser General Public 23 License along with this library; if not, write to the Free Software 24 Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 25 26""" 27 28import sys 29sys.path[0:0] = [""] 30import os 31import os.path 32import socket 33import websocket as ws 34from websocket._handshake import _create_sec_websocket_key, \ 35 _validate as _validate_header 36from websocket._http import read_headers 37from websocket._utils import validate_utf8 38from base64 import decodebytes as base64decode 39 40import unittest 41 42try: 43 import ssl 44 from ssl import SSLError 45except ImportError: 46 # dummy class of SSLError for ssl none-support environment. 47 class SSLError(Exception): 48 pass 49 50# Skip test to access the internet. 
TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1'
# Passed to ws.enableTrace() in WebSocketTest.setUp.
TRACEABLE = True


def create_mask_key(_):
    """Mask-key factory that always returns ``"abcd"``.

    It is installed with ``WebSocket.set_mask_key`` so that masked outgoing
    frames are deterministic and can be compared byte-for-byte. The argument
    the library passes in is ignored.
    """
    return "abcd"


class SockMock(object):
    """In-memory stand-in for a socket.

    ``add_packet`` queues either a bytes chunk or an exception instance, and
    ``recv`` replays the queue in order. Everything passed to ``send`` is
    recorded in ``sent``.
    """

    def __init__(self):
        self.data = []  # queued recv() results: bytes or exception instances
        self.sent = []  # every payload handed to send(), in order

    def add_packet(self, data):
        """Queue *data* (bytes, or an exception to raise) for a later recv()."""
        self.data.append(data)

    def gettimeout(self):
        """Report a blocking socket (no timeout)."""
        return None

    def recv(self, bufsize):
        """Return the next queued chunk, truncated to at most *bufsize* bytes.

        A queued exception is raised instead of returned. When a chunk is
        longer than *bufsize*, the remainder is pushed back to the head of
        the queue. When the queue is empty this returns None. The tests below
        rely on that to make the library raise
        WebSocketConnectionClosedException.
        """
        if not self.data:
            return None
        chunk = self.data.pop(0)
        if isinstance(chunk, Exception):
            raise chunk
        if len(chunk) > bufsize:
            self.data.insert(0, chunk[bufsize:])
        return chunk[:bufsize]

    def send(self, data):
        """Record *data* and report it as fully sent."""
        self.sent.append(data)
        return len(data)

    def close(self):
        """No-op: there is no real resource to release."""
        pass


class HeaderSockMock(SockMock):
    """SockMock preloaded with the raw bytes of a fixture file.

    *fname* is resolved relative to the directory containing this test file.
    """

    def __init__(self, fname):
        super().__init__()
        path = os.path.join(os.path.dirname(__file__), fname)
        with open(path, "rb") as f:
            self.add_packet(f.read())


class WebSocketTest(unittest.TestCase):
    """Tests for WebSocket framing, handshake helpers and connection state."""

    def setUp(self):
        ws.enableTrace(TRACEABLE)

    def testDefaultTimeout(self):
        self.assertIsNone(ws.getdefaulttimeout())
        # The default timeout is module-global state. Restore it even if an
        # assertion fails, so later tests are not affected.
        self.addCleanup(ws.setdefaulttimeout, None)
        ws.setdefaulttimeout(10)
        self.assertEqual(ws.getdefaulttimeout(), 10)

    def testWSKey(self):
        key = _create_sec_websocket_key()
        # A 16-byte nonce (see testNonce) base64-encodes to exactly 24 chars.
        self.assertEqual(len(key), 24)
        # The key must be a bare token without a newline (such as the one
        # base64.encodebytes appends).
        self.assertNotIn("\n", key)

    def testNonce(self):
        """WebSocket key should be a random 16-byte nonce."""
        key = _create_sec_websocket_key()
        nonce = base64decode(key.encode("utf-8"))
        self.assertEqual(16, len(nonce))

    def testWsUtils(self):
        """_validate_header returns (is_valid, accepted_subprotocol)."""
        key = "c6b8hTg4EeGb2gQMztV1/g=="
        required_header = {
            "upgrade": "websocket",
            "connection": "upgrade",
            "sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0="}
        self.assertEqual(_validate_header(required_header, key, None), (True, None))

        header = required_header.copy()
        header["upgrade"] = "http"
        self.assertEqual(_validate_header(header, key, None), (False, None))
        del header["upgrade"]
        self.assertEqual(_validate_header(header, key, None), (False, None))

        header = required_header.copy()
        header["connection"] = "something"
        self.assertEqual(_validate_header(header, key, None), (False, None))
        del header["connection"]
        self.assertEqual(_validate_header(header, key, None), (False, None))

        header = required_header.copy()
        header["sec-websocket-accept"] = "something"
        self.assertEqual(_validate_header(header, key, None), (False, None))
        del header["sec-websocket-accept"]
        self.assertEqual(_validate_header(header, key, None), (False, None))

        header = required_header.copy()
        header["sec-websocket-protocol"] = "sub1"
        self.assertEqual(_validate_header(header, key, ["sub1", "sub2"]), (True, "sub1"))
        # An error is logged for this mismatch; that output is expected.
        self.assertEqual(_validate_header(header, key, ["sub2", "sub3"]), (False, None))

        # Subprotocol matching ignores case; the result comes back lower-cased.
        header = required_header.copy()
        header["sec-websocket-protocol"] = "sUb1"
        self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (True, "sub1"))

        # Subprotocols were requested but the server named none.
        # An error is logged here too; that output is expected.
        header = required_header.copy()
        self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (False, None))

    def testReadHeader(self):
        status, header, _ = read_headers(HeaderSockMock("data/header01.txt"))
        self.assertEqual(status, 101)
        self.assertEqual(header["connection"], "Upgrade")

        status, header, _ = read_headers(HeaderSockMock("data/header03.txt"))
        self.assertEqual(status, 101)
        self.assertEqual(header["connection"], "Upgrade, Keep-Alive")

        # The header02.txt fixture must be rejected by read_headers.
        self.assertRaises(ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt"))

    def testSend(self):
        # TODO: add longer frame data
        sock = ws.WebSocket()
        sock.set_mask_key(create_mask_key)
        s = sock.sock = HeaderSockMock("data/header01.txt")
        sock.send("Hello")
        self.assertEqual(s.sent[0], b'\x81\x85abcd)\x07\x0f\x08\x0e')

        sock.send("こんにちは")
        self.assertEqual(s.sent[1], b'\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc')

        # 19 = 2-byte header + 4-byte mask key + 13-byte payload.
        self.assertEqual(sock.send_binary(b'1111111111101'), 19)

    def testRecv(self):
        # TODO: add longer frame data
        sock = ws.WebSocket()
        s = sock.sock = SockMock()
        something = b'\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc'
        s.add_packet(something)
        data = sock.recv()
        self.assertEqual(data, "こんにちは")

        s.add_packet(b'\x81\x85abcd)\x07\x0f\x08\x0e')
        data = sock.recv()
        self.assertEqual(data, "Hello")

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testIter(self):
        sock = ws.create_connection('wss://stream.meetup.com/2/rsvps')
        self.addCleanup(sock.close)
        count = 2
        for _ in sock:
            count -= 1
            if count == 0:
                break

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testNext(self):
        sock = ws.create_connection('wss://stream.meetup.com/2/rsvps')
        self.addCleanup(sock.close)
        self.assertEqual(str, type(next(sock)))

    def testInternalRecvStrict(self):
        sock = ws.WebSocket()
        s = sock.sock = SockMock()
        s.add_packet(b'foo')
        s.add_packet(socket.timeout())
        s.add_packet(b'bar')
        s.add_packet(b'baz')
        # The timeout interrupts the first 9-byte read after only b'foo'...
        with self.assertRaises(ws.WebSocketTimeoutException):
            sock.frame_buffer.recv_strict(9)
        # ...but those bytes are kept, so the retry returns all nine.
        data = sock.frame_buffer.recv_strict(9)
        self.assertEqual(data, b'foobarbaz')
        with self.assertRaises(ws.WebSocketConnectionClosedException):
            sock.frame_buffer.recv_strict(1)

    def testRecvTimeout(self):
        sock = ws.WebSocket()
        s = sock.sock = SockMock()
        # A single masked text frame ("Hello, World!") split by two timeouts.
        s.add_packet(b'\x81')
        s.add_packet(socket.timeout())
        s.add_packet(b'\x8dabcd\x29\x07\x0f\x08\x0e')
        s.add_packet(socket.timeout())
        s.add_packet(b'\x4e\x43\x33\x0e\x10\x0f\x00\x40')
        with self.assertRaises(ws.WebSocketTimeoutException):
            sock.recv()
        with self.assertRaises(ws.WebSocketTimeoutException):
            sock.recv()
        # Partial frame data survives the timeouts, so this recv() completes it.
        data = sock.recv()
        self.assertEqual(data, "Hello, World!")
        with self.assertRaises(ws.WebSocketConnectionClosedException):
            sock.recv()

    def testRecvWithSimpleFragmentation(self):
        sock = ws.WebSocket()
        s = sock.sock = SockMock()
        # OPCODE=TEXT, FIN=0, MSG="Brevity is "
        s.add_packet(b'\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C')
        # OPCODE=CONT, FIN=1, MSG="the soul of wit"
        s.add_packet(b'\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17')
        data = sock.recv()
        self.assertEqual(data, "Brevity is the soul of wit")
        with self.assertRaises(ws.WebSocketConnectionClosedException):
            sock.recv()

    def testRecvWithFireEventOfFragmentation(self):
        # With fire_cont_frame=True each fragment is delivered separately.
        sock = ws.WebSocket(fire_cont_frame=True)
        s = sock.sock = SockMock()
        # OPCODE=TEXT, FIN=0, MSG="Brevity is "
        s.add_packet(b'\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C')
        # OPCODE=CONT, FIN=0, MSG="Brevity is "
        s.add_packet(b'\x00\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C')
        # OPCODE=CONT, FIN=1, MSG="the soul of wit"
        s.add_packet(b'\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17')

        _, data = sock.recv_data()
        self.assertEqual(data, b'Brevity is ')
        _, data = sock.recv_data()
        self.assertEqual(data, b'Brevity is ')
        _, data = sock.recv_data()
        self.assertEqual(data, b'the soul of wit')

        # OPCODE=CONT, FIN=1, MSG="Brevity is ": a continuation frame arriving
        # after the previous message was already finished must be rejected.
        s.add_packet(b'\x80\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C')

        with self.assertRaises(ws.WebSocketException):
            sock.recv_data()

        with self.assertRaises(ws.WebSocketConnectionClosedException):
            sock.recv()

    def testClose(self):
        # close() on a socket flagged as connected, but with no mock socket
        # attached, raises.
        sock = ws.WebSocket()
        sock.connected = True
        self.assertRaises(ws.WebSocketConnectionClosedException, sock.close)

        # Receiving a CLOSE frame marks the WebSocket as disconnected.
        sock = ws.WebSocket()
        s = sock.sock = SockMock()
        sock.connected = True
        s.add_packet(b'\x88\x80\x17\x98p\x84')
        sock.recv()
        self.assertEqual(sock.connected, False)

    def testRecvContFragmentation(self):
        sock = ws.WebSocket()
        s = sock.sock = SockMock()
        # OPCODE=CONT, FIN=1, MSG="the soul of wit", with no message in progress
        s.add_packet(b'\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17')
        self.assertRaises(ws.WebSocketException, sock.recv)

    def testRecvWithProlongedFragmentation(self):
        sock = ws.WebSocket()
        s = sock.sock = SockMock()
        # OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, "
        s.add_packet(b'\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC')
        # OPCODE=CONT, FIN=0, MSG="dear friends, "
        s.add_packet(b'\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07\x17MB')
        # OPCODE=CONT, FIN=1, MSG="once more"
        s.add_packet(b'\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04')
        data = sock.recv()
        self.assertEqual(
            data,
            "Once more unto the breach, dear friends, once more")
        with self.assertRaises(ws.WebSocketConnectionClosedException):
            sock.recv()

    def testRecvWithFragmentationAndControlFrame(self):
        sock = ws.WebSocket()
        sock.set_mask_key(create_mask_key)
        s = sock.sock = SockMock()
        # OPCODE=TEXT, FIN=0, MSG="Too much "
        s.add_packet(b'\x01\x89abcd5\r\x0cD\x0c\x17\x00\x0cA')
        # OPCODE=PING, FIN=1, MSG="Please PONG this"
        s.add_packet(b'\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17')
        # OPCODE=CONT, FIN=1, MSG="of a good thing"
        s.add_packet(b'\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c\x08\x0c\x04')
        data = sock.recv()
        self.assertEqual(data, "Too much of a good thing")
        with self.assertRaises(ws.WebSocketConnectionClosedException):
            sock.recv()
        # The PING interleaved between fragments was answered with a PONG
        # (opcode 0xA) carrying the same payload.
        self.assertEqual(
            s.sent[0],
            b'\x8a\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17')

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testWebSocket(self):
        s = ws.create_connection("ws://echo.websocket.org/")
        self.assertIsNotNone(s)
        try:
            s.send("Hello, World")
            result = s.recv()
            self.assertEqual(result, "Hello, World")

            s.send("こにゃにゃちは、世界")
            result = s.recv()
            self.assertEqual(result, "こにゃにゃちは、世界")
            self.assertRaises(ValueError, s.send_close, -1, "")
        finally:
            s.close()

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testPingPong(self):
        s = ws.create_connection("ws://echo.websocket.org/")
        self.assertIsNotNone(s)
        try:
            s.ping("Hello")
            s.pong("Hi")
        finally:
            s.close()

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testSecureWebSocket(self):
        s = ws.create_connection("wss://api.bitfinex.com/ws/2")
        self.assertIsNotNone(s)
        try:
            self.assertIsInstance(s.sock, ssl.SSLSocket)
            self.assertEqual(s.getstatus(), 101)
            self.assertIsNotNone(s.getheaders())
            s.settimeout(10)
            self.assertEqual(s.gettimeout(), 10)
            self.assertIsNone(s.getsubprotocol())
        finally:
            s.abort()

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testWebSocketWithCustomHeader(self):
        s = ws.create_connection("ws://echo.websocket.org/",
                                 headers={"User-Agent": "PythonWebsocketClient"})
        self.assertIsNotNone(s)
        try:
            s.send("Hello, World")
            result = s.recv()
            self.assertEqual(result, "Hello, World")
            self.assertRaises(ValueError, s.close, -1, "")
        finally:
            s.close()

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testAfterClose(self):
        s = ws.create_connection("ws://echo.websocket.org/")
        self.assertIsNotNone(s)
        s.close()
        self.assertRaises(ws.WebSocketConnectionClosedException, s.send, "Hello")
        self.assertRaises(ws.WebSocketConnectionClosedException, s.recv)


class SockOptTest(unittest.TestCase):
    """Socket options passed to create_connection reach the real socket."""

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testSockOpt(self):
        sockopt = ((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),)
        s = ws.create_connection("ws://echo.websocket.org", sockopt=sockopt)
        try:
            self.assertNotEqual(s.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY), 0)
        finally:
            s.close()


class UtilsTest(unittest.TestCase):
    """Tests for helpers in websocket._utils."""

    def testUtf8Validator(self):
        # A valid 4-byte sequence.
        state = validate_utf8(b'\xf0\x90\x80\x80')
        self.assertEqual(state, True)
        # Contains \xed\xa0\x80, an encoded UTF-16 surrogate, which is invalid UTF-8.
        state = validate_utf8(b'\xce\xba\xe1\xbd\xb9\xcf\x83\xce\xbc\xce\xb5\xed\xa0\x80edited')
        self.assertEqual(state, False)
        state = validate_utf8(b'')
        self.assertEqual(state, True)


class HandshakeTest(unittest.TestCase):
    """Tests for connect() argument handling (SSL options, headers, URLs)."""

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def test_http_SSL(self):
        websock1 = ws.WebSocket(sslopt={"cert_chain": ssl.get_default_verify_paths().capath})
        self.assertRaises(ValueError,
                          websock1.connect, "wss://api.bitfinex.com/ws/2")
        websock2 = ws.WebSocket(sslopt={"certfile": "myNonexistentCertFile"})
        self.assertRaises(FileNotFoundError,
                          websock2.connect, "wss://api.bitfinex.com/ws/2")

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testManualHeaders(self):
        websock3 = ws.WebSocket(sslopt={"cert_reqs": ssl.CERT_NONE,
                                        "ca_certs": ssl.get_default_verify_paths().capath,
                                        "ca_cert_path": ssl.get_default_verify_paths().openssl_cafile})
        self.assertRaises(ws._exceptions.WebSocketBadStatusException,
                          websock3.connect, "wss://api.bitfinex.com/ws/2", cookie="chocolate",
                          origin="testing_websockets.com",
                          host="echo.websocket.org/websocket-client-test",
                          subprotocols=["testproto"],
                          connection="Upgrade",
                          header={"CustomHeader1": "123",
                                  "Cookie": "TestValue",
                                  "Sec-WebSocket-Key": "k9kFAUWNAMmf5OEMfTlOEA==",
                                  "Sec-WebSocket-Protocol": "newprotocol"})

    def testIPv6(self):
        # A bare IPv6 address (no scheme) is not an accepted URL.
        websock2 = ws.WebSocket()
        self.assertRaises(ValueError, websock2.connect, "2001:4860:4860::8888")

    def testBadURLs(self):
        websock3 = ws.WebSocket()
        self.assertRaises(ValueError, websock3.connect, "ws//example.com")
        self.assertRaises(ws.WebSocketAddressException, websock3.connect, "ws://example")
        self.assertRaises(ValueError, websock3.connect, "example.com")


if __name__ == "__main__":
    unittest.main()