1# -*- coding: utf-8 -*-
2#
3"""
4
5"""
6
7"""
8websocket - WebSocket client library for Python
9
10Copyright (C) 2010 Hiroki Ohtani(liris)
11
12    This library is free software; you can redistribute it and/or
13    modify it under the terms of the GNU Lesser General Public
14    License as published by the Free Software Foundation; either
15    version 2.1 of the License, or (at your option) any later version.
16
17    This library is distributed in the hope that it will be useful,
18    but WITHOUT ANY WARRANTY; without even the implied warranty of
19    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20    Lesser General Public License for more details.
21
22    You should have received a copy of the GNU Lesser General Public
23    License along with this library; if not, write to the Free Software
24    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
25
26"""
27
28import sys
29sys.path[0:0] = [""]
30import os
31import os.path
32import socket
33import websocket as ws
34from websocket._handshake import _create_sec_websocket_key, \
35    _validate as _validate_header
36from websocket._http import read_headers
37from websocket._utils import validate_utf8
38from base64 import decodebytes as base64decode
39
40import unittest
41
try:
    import ssl
    from ssl import SSLError
except ImportError:
    # The ssl module is unavailable in this environment: define a stand-in
    # SSLError so that the name still exists in this module.
    # NOTE(review): the name `ssl` itself stays undefined in this branch, so
    # the HandshakeTest cases that call ssl.get_default_verify_paths() would
    # fail with NameError if they were run without ssl -- confirm whether
    # that matters.
    class SSLError(Exception):
        pass

# Tests that need internet access are skipped unless the environment
# variable TEST_WITH_INTERNET is set to '1'.
TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1'
# Value passed to ws.enableTrace() in WebSocketTest.setUp.
TRACEABLE = True
53
54
def create_mask_key(_):
    """Return a fixed mask key so that masked frames are reproducible.

    The single positional argument is accepted and ignored.
    """
    fixed_key = "abcd"
    return fixed_key
57
58
class SockMock(object):
    """In-memory stand-in for a socket, driven by a queue of scripted items.

    Items queued with add_packet() are handed out by recv() in FIFO order;
    everything passed to send() is recorded in ``sent``.
    """

    def __init__(self):
        # Pending items for recv(): byte strings to return, or exception
        # instances to raise.
        self.data = []
        # Every payload passed to send(), in call order.
        self.sent = []

    def add_packet(self, data):
        """Queue *data* (bytes, or an exception instance) for a later recv()."""
        self.data.append(data)

    def gettimeout(self):
        """Always return None, the value socket.gettimeout() uses for 'no timeout'."""
        return None

    def recv(self, bufsize):
        """Return up to *bufsize* bytes from the next queued item.

        If the next item is an exception instance it is removed from the
        queue and raised.  If the next item is longer than *bufsize*, the
        unread remainder is put back at the front of the queue.

        Returns ``b""`` when nothing is queued, mirroring how a real socket
        reports end-of-stream (previously this case fell through and
        returned None implicitly).
        """
        if not self.data:
            return b""
        e = self.data.pop(0)
        if isinstance(e, Exception):
            raise e
        if len(e) > bufsize:
            # Keep the part that did not fit for the next call.
            self.data.insert(0, e[bufsize:])
        return e[:bufsize]

    def send(self, data):
        """Record *data* in ``sent`` and return its length."""
        self.sent.append(data)
        return len(data)

    def close(self):
        """Do nothing; there is no real resource to release."""
        pass
85
86
class HeaderSockMock(SockMock):
    """SockMock pre-loaded with the raw contents of a fixture file."""

    def __init__(self, fname):
        """Queue the bytes of *fname* as a single packet.

        *fname* is resolved relative to the directory containing this file.
        """
        SockMock.__init__(self)
        here = os.path.dirname(__file__)
        with open(os.path.join(here, fname), "rb") as fixture:
            payload = fixture.read()
        self.add_packet(payload)
94
95
class WebSocketTest(unittest.TestCase):
    """Tests for the WebSocket client: handshake helpers, framing, send/recv.

    Most tests replace the underlying socket with SockMock / HeaderSockMock
    so they run offline.  Tests decorated with
    ``skipUnless(TEST_WITH_INTERNET, ...)`` talk to real servers and only run
    when that flag is set.
    """

    def setUp(self):
        """Set library tracing to the module-level TRACEABLE flag."""
        ws.enableTrace(TRACEABLE)

    def tearDown(self):
        pass

    def testDefaultTimeout(self):
        """The module-wide default timeout starts as None and can be changed."""
        self.assertIsNone(ws.getdefaulttimeout())
        ws.setdefaulttimeout(10)
        try:
            self.assertEqual(ws.getdefaulttimeout(), 10)
        finally:
            # Always restore the default so that a failure here cannot leak
            # the timeout into later tests.
            ws.setdefaulttimeout(None)

    def testWSKey(self):
        """The generated key is 24 characters long with no embedded newline."""
        key = _create_sec_websocket_key()
        # 16 bytes (see testNonce) always base64-encode to 24 characters.
        self.assertEqual(len(key), 24)
        self.assertNotIn("\n", key)

    def testNonce(self):
        """ WebSocket key should be a random 16-byte nonce.
        """
        key = _create_sec_websocket_key()
        nonce = base64decode(key.encode("utf-8"))
        self.assertEqual(16, len(nonce))

    def testWsUtils(self):
        """Handshake validation checks upgrade, connection, accept and subprotocol."""
        key = "c6b8hTg4EeGb2gQMztV1/g=="
        required_header = {
            "upgrade": "websocket",
            "connection": "upgrade",
            "sec-websocket-accept": "Kxep+hNu9n51529fGidYu7a3wO0="}
        self.assertEqual(_validate_header(required_header, key, None), (True, None))

        # A wrong or missing "upgrade" header fails validation.
        header = required_header.copy()
        header["upgrade"] = "http"
        self.assertEqual(_validate_header(header, key, None), (False, None))
        del header["upgrade"]
        self.assertEqual(_validate_header(header, key, None), (False, None))

        # A wrong or missing "connection" header fails validation.
        header = required_header.copy()
        header["connection"] = "something"
        self.assertEqual(_validate_header(header, key, None), (False, None))
        del header["connection"]
        self.assertEqual(_validate_header(header, key, None), (False, None))

        # A wrong or missing "sec-websocket-accept" header fails validation.
        header = required_header.copy()
        header["sec-websocket-accept"] = "something"
        self.assertEqual(_validate_header(header, key, None), (False, None))
        del header["sec-websocket-accept"]
        self.assertEqual(_validate_header(header, key, None), (False, None))

        # The server-selected subprotocol must be one of those requested.
        header = required_header.copy()
        header["sec-websocket-protocol"] = "sub1"
        self.assertEqual(_validate_header(header, key, ["sub1", "sub2"]), (True, "sub1"))
        # This case will print out a logging error using the error() function, but that is expected
        self.assertEqual(_validate_header(header, key, ["sub2", "sub3"]), (False, None))

        # Subprotocol matching ignores case; the result is reported lower-cased.
        header = required_header.copy()
        header["sec-websocket-protocol"] = "sUb1"
        self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (True, "sub1"))

        # Subprotocols requested but none returned by the server.
        header = required_header.copy()
        # This case will print out a logging error using the error() function, but that is expected
        self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (False, None))

    def testReadHeader(self):
        """read_headers parses the fixture responses and rejects header02."""
        status, header, status_message = read_headers(HeaderSockMock("data/header01.txt"))
        self.assertEqual(status, 101)
        self.assertEqual(header["connection"], "Upgrade")

        status, header, status_message = read_headers(HeaderSockMock("data/header03.txt"))
        self.assertEqual(status, 101)
        self.assertEqual(header["connection"], "Upgrade, Keep-Alive")

        self.assertRaises(ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt"))

    def testSend(self):
        """send() and send_binary() emit masked frames using the fixed mask key."""
        # TODO: add longer frame data
        sock = ws.WebSocket()
        sock.set_mask_key(create_mask_key)
        s = sock.sock = HeaderSockMock("data/header01.txt")
        sock.send("Hello")
        self.assertEqual(s.sent[0], b'\x81\x85abcd)\x07\x0f\x08\x0e')

        sock.send("こんにちは")
        self.assertEqual(s.sent[1], b'\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc')

        # 13 payload bytes + 2 header bytes + 4 mask-key bytes.
        self.assertEqual(sock.send_binary(b'1111111111101'), 19)

    def testRecv(self):
        """recv() unmasks text frames and returns the decoded string."""
        # TODO: add longer frame data
        sock = ws.WebSocket()
        s = sock.sock = SockMock()
        something = b'\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc'
        s.add_packet(something)
        data = sock.recv()
        self.assertEqual(data, "こんにちは")

        s.add_packet(b'\x81\x85abcd)\x07\x0f\x08\x0e')
        data = sock.recv()
        self.assertEqual(data, "Hello")

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testIter(self):
        """Iterating over a connection yields successive messages."""
        count = 2
        s = ws.create_connection('wss://stream.meetup.com/2/rsvps')
        try:
            for _ in s:
                count -= 1
                if count == 0:
                    break
        finally:
            # Previously the connection was never closed.
            s.close()

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testNext(self):
        """next() on a connection returns a text message."""
        sock = ws.create_connection('wss://stream.meetup.com/2/rsvps')
        try:
            self.assertIsInstance(next(sock), str)
        finally:
            # Previously the connection was never closed.
            sock.close()

    def testInternalRecvStrict(self):
        """recv_strict keeps already-received bytes across a timeout."""
        sock = ws.WebSocket()
        s = sock.sock = SockMock()
        s.add_packet(b'foo')
        s.add_packet(socket.timeout())
        s.add_packet(b'bar')
        # s.add_packet(SSLError("The read operation timed out"))
        s.add_packet(b'baz')
        with self.assertRaises(ws.WebSocketTimeoutException):
            sock.frame_buffer.recv_strict(9)
        #     with self.assertRaises(SSLError):
        #         data = sock._recv_strict(9)
        # The second call still returns 'foo', read before the timeout.
        data = sock.frame_buffer.recv_strict(9)
        self.assertEqual(data, b'foobarbaz')
        with self.assertRaises(ws.WebSocketConnectionClosedException):
            sock.frame_buffer.recv_strict(1)

    def testRecvTimeout(self):
        """recv() raises on each mid-frame timeout and resumes afterwards."""
        sock = ws.WebSocket()
        s = sock.sock = SockMock()
        s.add_packet(b'\x81')
        s.add_packet(socket.timeout())
        s.add_packet(b'\x8dabcd\x29\x07\x0f\x08\x0e')
        s.add_packet(socket.timeout())
        s.add_packet(b'\x4e\x43\x33\x0e\x10\x0f\x00\x40')
        with self.assertRaises(ws.WebSocketTimeoutException):
            sock.recv()
        with self.assertRaises(ws.WebSocketTimeoutException):
            sock.recv()
        data = sock.recv()
        self.assertEqual(data, "Hello, World!")
        with self.assertRaises(ws.WebSocketConnectionClosedException):
            sock.recv()

    def testRecvWithSimpleFragmentation(self):
        """A two-frame fragmented text message is reassembled by recv()."""
        sock = ws.WebSocket()
        s = sock.sock = SockMock()
        # OPCODE=TEXT, FIN=0, MSG="Brevity is "
        s.add_packet(b'\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C')
        # OPCODE=CONT, FIN=1, MSG="the soul of wit"
        s.add_packet(b'\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17')
        data = sock.recv()
        self.assertEqual(data, "Brevity is the soul of wit")
        with self.assertRaises(ws.WebSocketConnectionClosedException):
            sock.recv()

    def testRecvWithFireEventOfFragmentation(self):
        """With fire_cont_frame=True each fragment is returned individually."""
        sock = ws.WebSocket(fire_cont_frame=True)
        s = sock.sock = SockMock()
        # OPCODE=TEXT, FIN=0, MSG="Brevity is "
        s.add_packet(b'\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C')
        # OPCODE=CONT, FIN=0, MSG="Brevity is "
        s.add_packet(b'\x00\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C')
        # OPCODE=CONT, FIN=1, MSG="the soul of wit"
        s.add_packet(b'\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17')

        _, data = sock.recv_data()
        self.assertEqual(data, b'Brevity is ')
        _, data = sock.recv_data()
        self.assertEqual(data, b'Brevity is ')
        _, data = sock.recv_data()
        self.assertEqual(data, b'the soul of wit')

        # OPCODE=CONT, FIN=1, MSG="Brevity is " -- a continuation frame
        # arriving after the previous message has already finished.
        s.add_packet(b'\x80\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C')

        with self.assertRaises(ws.WebSocketException):
            sock.recv_data()

        with self.assertRaises(ws.WebSocketConnectionClosedException):
            sock.recv()

    def testClose(self):
        """close() without a socket raises; receiving a close frame disconnects."""
        sock = ws.WebSocket()
        sock.connected = True
        self.assertRaises(ws.WebSocketConnectionClosedException, sock.close)

        sock = ws.WebSocket()
        s = sock.sock = SockMock()
        sock.connected = True
        # OPCODE=CLOSE, FIN=1, empty payload
        s.add_packet(b'\x88\x80\x17\x98p\x84')
        sock.recv()
        self.assertEqual(sock.connected, False)

    def testRecvContFragmentation(self):
        """A continuation frame with no preceding start frame is an error."""
        sock = ws.WebSocket()
        s = sock.sock = SockMock()
        # OPCODE=CONT, FIN=1, MSG="the soul of wit"
        s.add_packet(b'\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17')
        self.assertRaises(ws.WebSocketException, sock.recv)

    def testRecvWithProlongedFragmentation(self):
        """A three-frame fragmented text message is reassembled by recv()."""
        sock = ws.WebSocket()
        s = sock.sock = SockMock()
        # OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, "
        s.add_packet(b'\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC')
        # OPCODE=CONT, FIN=0, MSG="dear friends, "
        s.add_packet(b'\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07\x17MB')
        # OPCODE=CONT, FIN=1, MSG="once more"
        s.add_packet(b'\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04')
        data = sock.recv()
        self.assertEqual(
            data,
            "Once more unto the breach, dear friends, once more")
        with self.assertRaises(ws.WebSocketConnectionClosedException):
            sock.recv()

    def testRecvWithFragmentationAndControlFrame(self):
        """A PING between fragments is answered without breaking reassembly."""
        sock = ws.WebSocket()
        sock.set_mask_key(create_mask_key)
        s = sock.sock = SockMock()
        # OPCODE=TEXT, FIN=0, MSG="Too much "
        s.add_packet(b'\x01\x89abcd5\r\x0cD\x0c\x17\x00\x0cA')
        # OPCODE=PING, FIN=1, MSG="Please PONG this"
        s.add_packet(b'\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17')
        # OPCODE=CONT, FIN=1, MSG="of a good thing"
        s.add_packet(b'\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c\x08\x0c\x04')
        data = sock.recv()
        self.assertEqual(data, "Too much of a good thing")
        with self.assertRaises(ws.WebSocketConnectionClosedException):
            sock.recv()
        # The reply is a PONG (opcode 0xA) carrying the PING payload.
        self.assertEqual(
            s.sent[0],
            b'\x8a\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17')

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testWebSocket(self):
        """Text sent to the echo server comes back unchanged."""
        s = ws.create_connection("ws://echo.websocket.org/")
        self.assertIsNotNone(s)
        s.send("Hello, World")
        result = s.recv()
        self.assertEqual(result, "Hello, World")

        s.send("こにゃにゃちは、世界")
        result = s.recv()
        self.assertEqual(result, "こにゃにゃちは、世界")
        self.assertRaises(ValueError, s.send_close, -1, "")
        s.close()

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testPingPong(self):
        """ping() and pong() can be sent on a live connection."""
        s = ws.create_connection("ws://echo.websocket.org/")
        self.assertIsNotNone(s)
        s.ping("Hello")
        s.pong("Hi")
        s.close()

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testSecureWebSocket(self):
        """A wss:// connection uses an SSLSocket and exposes handshake details."""
        import ssl
        s = ws.create_connection("wss://api.bitfinex.com/ws/2")
        self.assertIsNotNone(s)
        self.assertIsInstance(s.sock, ssl.SSLSocket)
        self.assertEqual(s.getstatus(), 101)
        self.assertIsNotNone(s.getheaders())
        s.settimeout(10)
        self.assertEqual(s.gettimeout(), 10)
        self.assertIsNone(s.getsubprotocol())
        s.abort()

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testWebSocketWithCustomHeader(self):
        """A connection opened with extra headers still echoes text."""
        s = ws.create_connection("ws://echo.websocket.org/",
                                 headers={"User-Agent": "PythonWebsocketClient"})
        self.assertIsNotNone(s)
        s.send("Hello, World")
        result = s.recv()
        self.assertEqual(result, "Hello, World")
        self.assertRaises(ValueError, s.close, -1, "")
        s.close()

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testAfterClose(self):
        """send() and recv() raise once the connection has been closed."""
        s = ws.create_connection("ws://echo.websocket.org/")
        self.assertIsNotNone(s)
        s.close()
        self.assertRaises(ws.WebSocketConnectionClosedException, s.send, "Hello")
        self.assertRaises(ws.WebSocketConnectionClosedException, s.recv)
393
394
class SockOptTest(unittest.TestCase):
    """Checks that socket options given to create_connection are applied."""

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testSockOpt(self):
        """Connect with TCP_NODELAY requested and read the option back."""
        level, option = socket.IPPROTO_TCP, socket.TCP_NODELAY
        requested = ((level, option, 1),)
        conn = ws.create_connection("ws://echo.websocket.org", sockopt=requested)
        current = conn.sock.getsockopt(level, option)
        self.assertNotEqual(current, 0)
        conn.close()
402
403
class UtilsTest(unittest.TestCase):
    """Tests for validate_utf8."""

    def testUtf8Validator(self):
        """validate_utf8 accepts well-formed input and rejects a surrogate."""
        # U+10000 encoded as a four-byte sequence.
        self.assertEqual(validate_utf8(b'\xf0\x90\x80\x80'), True)

        # Contains \xed\xa0\x80, the byte pattern for the surrogate U+D800.
        with_surrogate = b'\xce\xba\xe1\xbd\xb9\xcf\x83\xce\xbc\xce\xb5\xed\xa0\x80edited'
        self.assertEqual(validate_utf8(with_surrogate), False)

        # An empty byte string is accepted.
        self.assertEqual(validate_utf8(b''), True)
412
413
class HandshakeTest(unittest.TestCase):
    """Connection-setup tests: SSL options, handshake headers, URL checks."""

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def test_http_SSL(self):
        """connect() raises for the two unusable certificate options below."""
        capath = ssl.get_default_verify_paths().capath
        with_cert_chain = ws.WebSocket(sslopt={"cert_chain": capath})
        with self.assertRaises(ValueError):
            with_cert_chain.connect("wss://api.bitfinex.com/ws/2")

        with_missing_certfile = ws.WebSocket(sslopt={"certfile": "myNonexistentCertFile"})
        with self.assertRaises(FileNotFoundError):
            with_missing_certfile.connect("wss://api.bitfinex.com/ws/2")

    @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
    def testManualHeaders(self):
        """A handshake with these hand-set headers ends in a bad-status error."""
        verify_paths = ssl.get_default_verify_paths()
        sslopt = {
            "cert_reqs": ssl.CERT_NONE,
            "ca_certs": verify_paths.capath,
            "ca_cert_path": verify_paths.openssl_cafile,
        }
        client = ws.WebSocket(sslopt=sslopt)
        extra_headers = {
            "CustomHeader1": "123",
            "Cookie": "TestValue",
            "Sec-WebSocket-Key": "k9kFAUWNAMmf5OEMfTlOEA==",
            "Sec-WebSocket-Protocol": "newprotocol",
        }
        with self.assertRaises(ws._exceptions.WebSocketBadStatusException):
            client.connect(
                "wss://api.bitfinex.com/ws/2",
                cookie="chocolate",
                origin="testing_websockets.com",
                host="echo.websocket.org/websocket-client-test",
                subprotocols=["testproto"],
                connection="Upgrade",
                header=extra_headers,
            )

    def testIPv6(self):
        """A bare IPv6 address with no URL scheme raises ValueError."""
        client = ws.WebSocket()
        with self.assertRaises(ValueError):
            client.connect("2001:4860:4860::8888")

    def testBadURLs(self):
        """Each bad URL below raises the expected exception type."""
        client = ws.WebSocket()
        with self.assertRaises(ValueError):
            client.connect("ws//example.com")
        with self.assertRaises(ws.WebSocketAddressException):
            client.connect("ws://example")
        with self.assertRaises(ValueError):
            client.connect("example.com")
449
450
451if __name__ == "__main__":
452    unittest.main()
453