1import errno
2import struct
3import re
4
5import eventlet
6from eventlet import event
7from eventlet import websocket
8from eventlet.green import httplib
9from eventlet.green import socket
10import six
11
12import tests.wsgi_test
13
14
15# demo app
def handle(ws):
    """Demo websocket application that dispatches on ``ws.path``.

    ``/echo``   -- send every received message back until ``ws.wait()``
                   returns None.
    ``/range``  -- send ``"msg 0"`` .. ``"msg 9"`` with a short pause between.
    ``/error``  -- raise a socket error (ENOTSOCK) that the server should not
                   normally encounter.
    other paths -- close the websocket immediately.
    """
    path = ws.path

    if path == '/echo':
        message = ws.wait()
        while message is not None:
            ws.send(message)
            message = ws.wait()
        return

    if path == '/range':
        for n in range(10):
            ws.send("msg %d" % n)
            eventlet.sleep(0.01)
        return

    if path == '/error':
        # An arbitrary socket error we shouldn't normally get.
        raise socket.error(errno.ENOTSOCK)

    ws.close()
32
33
# Use a much lower frame-length limit than websocket's DEFAULT_MAX_FRAME_LENGTH
# for testing: sending an 8MiB frame over the loopback interface can trigger a
# timeout. The large-frame tests send frames of exactly this many bytes and of
# one byte more.
TEST_MAX_FRAME_LENGTH = 50000
# WSGI application under test: the demo ``handle`` app wrapped by eventlet's
# WebSocketWSGI with the reduced frame-length limit above.
wsapp = websocket.WebSocketWSGI(handle, max_frame_length=TEST_MAX_FRAME_LENGTH)
39
40
class TestWebSocket(tests.wsgi_test._TestBase):
    """Handshake and framing tests for websocket protocol version 13 (RFC 6455).

    Every test talks to ``self.site`` (``wsapp`` unless a test wraps it) over a
    real connection to ``self.server_addr`` provided by ``_TestBase``.
    """
    TEST_TIMEOUT = 5

    def set_site(self):
        self.site = wsapp

    def _upgrade_request(self, connection='Upgrade'):
        """Return the raw bytes of a version-13 upgrade request for ``/echo``.

        :param connection: value sent in the ``Connection`` header.
        """
        return six.b('\r\n'.join([
            "GET /echo HTTP/1.1",
            "Upgrade: websocket",
            "Connection: %s" % connection,
            "Host: %s:%s" % self.server_addr,
            "Origin: http://%s:%s" % self.server_addr,
            "Sec-WebSocket-Version: 13",
            "Sec-WebSocket-Key: d9MXuOzlVQ0h+qRllvSCIg==",
        ]) + '\r\n\r\n')

    def _connect_and_upgrade(self, connection='Upgrade'):
        """Open a socket, send the upgrade request and read up to 1024 bytes.

        :returns: ``(sock, reply)``; the caller owns ``sock``.
        """
        sock = eventlet.connect(self.server_addr)
        sock.sendall(self._upgrade_request(connection))
        return sock, sock.recv(1024)

    def _assert_bad_handshake(self, header_lines):
        """Send a GET /echo with *header_lines* and expect an empty 400 reply.

        :param header_lines: ``"Name: value"`` strings to send as headers.
        """
        headers = dict(kv.split(': ') for kv in header_lines)
        http = httplib.HTTPConnection(*self.server_addr)
        try:
            http.request("GET", "/echo", headers=headers)
            resp = http.getresponse()

            self.assertEqual(resp.status, 400)
            self.assertEqual(resp.getheader('connection'), 'close')
            self.assertEqual(resp.read(), b'')
        finally:
            http.close()

    def _install_error_detector(self):
        """Wrap ``self.site`` to record app failures, then respawn the server.

        :returns: ``(error_detected, done_with_request)`` -- a one-element list
            whose item becomes True if the wrapped app raised, and an Event
            that is sent once the wrapped app call finishes either way.
        """
        error_detected = [False]
        done_with_request = event.Event()
        site = self.site

        def error_detector(environ, start_response):
            try:
                try:
                    return site(environ, start_response)
                except:  # noqa: E722 -- record *any* failure, then re-raise
                    error_detected[0] = True
                    raise
            finally:
                done_with_request.send(True)
        self.site = error_detector
        self.spawn_server()
        return error_detected, done_with_request

    def test_incomplete_headers_13(self):
        host = "Host: %s:%s" % self.server_addr
        origin = "Origin: http://%s:%s" % self.server_addr
        version = "Sec-WebSocket-Version: 13"

        # NOTE: intentionally no Connection header
        self._assert_bad_handshake(["Upgrade: websocket", host, origin, version])

        # Now, miss off key
        self._assert_bad_handshake(
            ["Upgrade: websocket", "Connection: Upgrade", host, origin, version])

        # No Upgrade now
        self._assert_bad_handshake(["Connection: Upgrade", host, origin, version])

    def test_correct_upgrade_request_13(self):
        # The Accept value is the one derived from the fixed request key.
        expected = six.b('\r\n'.join([
            'HTTP/1.1 101 Switching Protocols',
            'Upgrade: websocket',
            'Connection: Upgrade',
            'Sec-WebSocket-Accept: ywSyWXCPNsDxLrQdQrn5RFNRfBU=\r\n\r\n',
        ]))
        for http_connection in ['Upgrade', 'UpGrAdE', 'keep-alive, Upgrade']:
            sock, result = self._connect_and_upgrade(http_connection)
            try:
                # The server responds with the correct websocket handshake.
                self.assertEqual(result, expected,
                                 'Connection string: %r' % http_connection)
            finally:
                sock.close()

    def test_send_recv_13(self):
        sock, _ = self._connect_and_upgrade()
        ws = websocket.RFC6455WebSocket(sock, {}, client=True)
        ws.send(b'hello')
        assert ws.wait() == b'hello'
        ws.send(b'hello world!\x01')
        ws.send(u'hello world again!')
        assert ws.wait() == b'hello world!\x01'
        assert ws.wait() == u'hello world again!'
        ws.close()
        eventlet.sleep(0.01)

    def test_breaking_the_connection_13(self):
        error_detected, done_with_request = self._install_error_detector()
        sock, _ = self._connect_and_upgrade()  # read the handshake headers
        sock.close()  # close while the app is running
        done_with_request.wait()
        assert not error_detected[0]

    def test_client_closing_connection_13(self):
        error_detected, done_with_request = self._install_error_detector()
        sock, _ = self._connect_and_upgrade()  # read the handshake headers
        try:
            # Masked close frame: FIN | opcode 8, MASK | payload length 2,
            # all-zero masking key, then status code 1000 (normal closure).
            closeframe = struct.pack('!BBIH', 1 << 7 | 8, 1 << 7 | 2, 0, 1000)
            sock.sendall(closeframe)
            done_with_request.wait()
        finally:
            sock.close()
        assert not error_detected[0]

    def test_client_invalid_packet_13(self):
        error_detected, done_with_request = self._install_error_detector()
        sock, _ = self._connect_and_upgrade()  # read the handshake headers
        try:
            # Weird packet: non-final frame with reserved opcode 0x7.
            sock.sendall(b'\x07\xff')
            done_with_request.wait()
        finally:
            sock.close()
        assert not error_detected[0]
237
238
class TestWebSocketWithCompression(tests.wsgi_test._TestBase):
    """permessage-deflate (RFC 7692) negotiation and compressed framing tests.

    Each test performs the upgrade handshake on a raw socket itself and, where
    needed, wraps the socket in a client-side ``RFC6455WebSocket`` configured
    with the extension parameters the test expects to have been negotiated.
    """
    TEST_TIMEOUT = 5

    def set_site(self):
        self.site = wsapp

    def setUp(self):
        super(TestWebSocketWithCompression, self).setUp()
        # Upgrade request header lines shared by every test (no extension offer,
        # no terminating blank line).
        self._request_lines = [
            "GET /echo HTTP/1.1",
            "Upgrade: websocket",
            "Connection: upgrade",
            "Host: %s:%s" % self.server_addr,
            "Origin: http://%s:%s" % self.server_addr,
            "Sec-WebSocket-Version: 13",
            "Sec-WebSocket-Key: d9MXuOzlVQ0h+qRllvSCIg==",
        ]
        # Request template whose single ``%s`` is the extension offer. Kept for
        # compatibility; tests build requests via ``_request()``.
        self.connect = '\r\n'.join(
            self._request_lines + ["Sec-WebSocket-Extensions: %s", '\r\n'])
        # Expected 101 reply; group 1 captures the accepted extension string.
        self.handshake_re = re.compile(six.b('\r\n'.join([
            'HTTP/1.1 101 Switching Protocols',
            'Upgrade: websocket',
            'Connection: Upgrade',
            'Sec-WebSocket-Accept: ywSyWXCPNsDxLrQdQrn5RFNRfBU=',
            'Sec-WebSocket-Extensions: (.+)\r\n',
        ])))

    def _request(self, extensions=None):
        """Return the raw upgrade request as bytes.

        :param extensions: value of the Sec-WebSocket-Extensions header to
            offer, or None to send no such header at all.
        """
        lines = list(self._request_lines)
        if extensions is not None:
            lines.append("Sec-WebSocket-Extensions: %s" % extensions)
        return six.b('\r\n'.join(lines) + '\r\n\r\n')

    def _handshake(self, extensions=None):
        """Connect, send the upgrade request and read up to 1024 bytes.

        :returns: ``(sock, reply)``; the caller owns ``sock``.
        """
        sock = eventlet.connect(self.server_addr)
        sock.sendall(self._request(extensions))
        return sock, sock.recv(1024)

    def _handshake_reply(self, extensions):
        """Like ``_handshake`` but close the socket and return only the reply."""
        sock, result = self._handshake(extensions)
        sock.close()
        return result

    @staticmethod
    def _deflate(client_no_context_takeover=False,
                 server_no_context_takeover=False):
        """Build the client-side ``extensions`` mapping for permessage-deflate."""
        return {'permessage-deflate': {
            'client_no_context_takeover': client_no_context_takeover,
            'server_no_context_takeover': server_no_context_takeover}}

    @staticmethod
    def get_deflated_reply(ws):
        """Read one frame from *ws* and return its payload still deflated.

        Clearing the frame's decompressor makes ``getvalue()`` return the raw
        compressed bytes, which the tests compare against RFC 7692 examples.
        """
        msg = ws._recv_frame(None)
        msg.decompressor = None
        return msg.getvalue()

    def _assert_echo_roundtrip(self, ws):
        """Exchange binary and text messages with the echo app, then close."""
        ws.send(b'hello')
        assert ws.wait() == b'hello'
        ws.send(b'hello world!')
        ws.send(u'hello world again!')
        assert ws.wait() == b'hello world!'
        assert ws.wait() == u'hello world again!'

        ws.close()
        eventlet.sleep(0.01)

    def test_accept_basic_deflate_ext_13(self):
        for extension in [
            'permessage-deflate',
            'PeRMessAGe-dEFlaTe',
        ]:
            result = self._handshake_reply(extension)

            # The server responds with the correct websocket handshake.
            match = re.match(self.handshake_re, result)
            assert match is not None, extension
            assert len(match.groups()) == 1

    def test_accept_deflate_ext_context_takeover_13(self):
        for extension in [
            'permessage-deflate;CLient_No_conteXT_TAkeOver',
            'permessage-deflate;   SerVER_No_conteXT_TAkeOver',
            'permessage-deflate; server_no_context_takeover; client_no_context_takeover',
        ]:
            result = self._handshake_reply(extension)

            match = re.match(self.handshake_re, result)
            assert match is not None, extension
            assert len(match.groups()) == 1
            # Every offered parameter is echoed back, normalised to lower case.
            offered_ext_parts = (ex.strip().lower() for ex in extension.split(';'))
            accepted_ext_parts = match.groups()[0].decode().split('; ')
            assert all(oep in accepted_ext_parts for oep in offered_ext_parts)

    def test_accept_deflate_ext_window_max_bits_13(self):
        # (offer, expected window bits per offered parameter, in order)
        for extension_string, vals in [
            ('permessage-deflate; client_max_window_bits', [15]),
            ('permessage-deflate;   Server_Max_Window_Bits  =  11', [11]),
            ('permessage-deflate; server_max_window_bits; '
             'client_max_window_bits=9', [15, 9])
        ]:
            result = self._handshake_reply(extension_string)

            match = re.match(self.handshake_re, result)
            assert match is not None, extension_string
            assert len(match.groups()) == 1

            offered_parts = [part.strip().lower() for part in extension_string.split(';')]
            offered_parts_names = [part.split('=')[0].strip() for part in offered_parts]
            offered_parts_dict = dict(zip(offered_parts_names[1:], vals))

            accepted_ext_parts = match.groups()[0].decode().split('; ')
            assert accepted_ext_parts[0] == 'permessage-deflate'
            for param, val in (part.split('=') for part in accepted_ext_parts[1:]):
                assert int(val) == offered_parts_dict[param]

    def test_reject_max_window_bits_out_of_range_13(self):
        # The first three offers use window bits outside RFC 7692's 8..15
        # range, so only the final, parameter-less offer can be accepted.
        extension_string = ('permessage-deflate; client_max_window_bits=7,'
                            'permessage-deflate; server_max_window_bits=16, '
                            'permessage-deflate; client_max_window_bits=16; '
                            'server_max_window_bits=7, '
                            'permessage-deflate')
        result = self._handshake_reply(extension_string)

        match = re.match(self.handshake_re, result)
        assert match is not None
        assert match.groups()[0] == b'permessage-deflate'

    def test_server_compress_with_context_takeover_13(self):
        sock, _ = self._handshake('permessage-deflate; client_no_context_takeover;')
        ws = websocket.RFC6455WebSocket(
            sock, {}, client=True,
            extensions=self._deflate(client_no_context_takeover=True))

        # Deflated values taken from Section 7.2.3 of RFC 7692
        # https://tools.ietf.org/html/rfc7692#section-7.2.3
        ws.send(b'Hello')
        assert self.get_deflated_reply(ws) == b'\xf2\x48\xcd\xc9\xc9\x07\x00'

        # The server keeps its context, so the second reply back-references
        # the first and is shorter.
        ws.send(b'Hello')
        assert self.get_deflated_reply(ws) == b'\xf2\x00\x11\x00\x00'

        ws.close()
        eventlet.sleep(0.01)

    def test_server_compress_no_context_takeover_13(self):
        sock, _ = self._handshake('permessage-deflate; server_no_context_takeover;')
        ws = websocket.RFC6455WebSocket(
            sock, {}, client=True,
            extensions=self._deflate(server_no_context_takeover=True))

        masked_msg1 = ws._pack_message(b'Hello', masked=True)
        ws._send(masked_msg1)
        masked_msg2 = ws._pack_message(b'Hello', masked=True)
        ws._send(masked_msg2)
        # Verify that the client uses context takeover: its second frame
        # back-references the first and therefore comes out shorter.
        assert len(masked_msg2) < len(masked_msg1)

        # Verify that the server drops context between messages: both replies
        # are compressed from scratch.
        # Deflated values taken from Section 7.2.3 of RFC 7692
        # https://tools.ietf.org/html/rfc7692#section-7.2.3
        assert self.get_deflated_reply(ws) == b'\xf2\x48\xcd\xc9\xc9\x07\x00'
        assert self.get_deflated_reply(ws) == b'\xf2\x48\xcd\xc9\xc9\x07\x00'

        ws.close()
        eventlet.sleep(0.01)

    def test_client_compress_with_context_takeover_13(self):
        ws = websocket.RFC6455WebSocket(
            None, {}, client=True,
            extensions=self._deflate(server_no_context_takeover=True))

        # Deflated values taken from Section 7.2.3 of RFC 7692
        # modified opcode to Binary instead of Text
        # https://tools.ietf.org/html/rfc7692#section-7.2.3
        packed_msg_1 = ws._pack_message(b'Hello', masked=False)
        assert packed_msg_1 == b'\xc2\x07\xf2\x48\xcd\xc9\xc9\x07\x00'
        packed_msg_2 = ws._pack_message(b'Hello', masked=False)
        assert packed_msg_2 == b'\xc2\x05\xf2\x00\x11\x00\x00'

        eventlet.sleep(0.01)

    def test_client_compress_no_context_takeover_13(self):
        ws = websocket.RFC6455WebSocket(
            None, {}, client=True,
            extensions=self._deflate(client_no_context_takeover=True))

        # Deflated values taken from Section 7.2.3 of RFC 7692
        # modified opcode to Binary instead of Text
        # https://tools.ietf.org/html/rfc7692#section-7.2.3
        packed_msg_1 = ws._pack_message(b'Hello', masked=False)
        assert packed_msg_1 == b'\xc2\x07\xf2\x48\xcd\xc9\xc9\x07\x00'
        packed_msg_2 = ws._pack_message(b'Hello', masked=False)
        assert packed_msg_2 == b'\xc2\x07\xf2\x48\xcd\xc9\xc9\x07\x00'

    def test_compressed_send_recv_13(self):
        sock, _ = self._handshake('permessage-deflate')
        ws = websocket.RFC6455WebSocket(sock, {}, client=True,
                                        extensions=self._deflate())
        self._assert_echo_roundtrip(ws)

    def test_send_uncompressed_msg_13(self):
        sock, _ = self._handshake('permessage-deflate')

        # Send without using deflate, having rsv1 unset
        ws = websocket.RFC6455WebSocket(sock, {}, client=True)
        ws.send(b'Hello')

        # Adding extensions to recognise deflated response
        ws.extensions = self._deflate()
        assert ws.wait() == b'Hello'

        ws.close()
        eventlet.sleep(0.01)

    def test_compressed_send_recv_client_no_context_13(self):
        sock, _ = self._handshake('permessage-deflate; client_no_context_takeover')
        ws = websocket.RFC6455WebSocket(
            sock, {}, client=True,
            extensions=self._deflate(client_no_context_takeover=True))
        self._assert_echo_roundtrip(ws)

    def test_compressed_send_recv_server_no_context_13(self):
        sock, _ = self._handshake('permessage-deflate; server_no_context_takeover')
        # The client config must match the offer: the server resets its
        # compressor for every message.
        ws = websocket.RFC6455WebSocket(
            sock, {}, client=True,
            extensions=self._deflate(server_no_context_takeover=True))
        self._assert_echo_roundtrip(ws)

    def test_compressed_send_recv_both_no_context_13(self):
        sock, _ = self._handshake('permessage-deflate;'
                                  ' server_no_context_takeover; client_no_context_takeover')
        ws = websocket.RFC6455WebSocket(
            sock, {}, client=True,
            extensions=self._deflate(client_no_context_takeover=True,
                                     server_no_context_takeover=True))
        self._assert_echo_roundtrip(ws)

    def test_large_frame_size_compressed_13(self):
        # Test fix for GHSA-9p9m-jm8w-94p2
        sock, _ = self._handshake('permessage-deflate')
        ws = websocket.RFC6455WebSocket(sock, {}, client=True,
                                        extensions=self._deflate())

        should_still_fit = b"x" * TEST_MAX_FRAME_LENGTH
        one_too_much = should_still_fit + b"x"

        # send just fitting frame twice to make sure they are fine independently
        ws.send(should_still_fit)
        assert ws.wait() == should_still_fit
        ws.send(should_still_fit)
        assert ws.wait() == should_still_fit
        ws.send(one_too_much)

        res = ws.wait()
        assert res is None  # socket closed
        # TODO: The websocket currently sends compressed control frames, which
        # contradicts RFC 7692. Re-enable the following assert once fixed.
        # assert ws._remote_close_data == b"\x03\xf1Incoming compressed frame is above length limit."
        eventlet.sleep(0.01)

    def test_large_frame_size_uncompressed_13(self):
        # Test fix for GHSA-9p9m-jm8w-94p2
        # No extension offer at all, so every frame travels uncompressed.
        sock, _ = self._handshake()
        ws = websocket.RFC6455WebSocket(sock, {}, client=True)

        should_still_fit = b"x" * TEST_MAX_FRAME_LENGTH
        one_too_much = should_still_fit + b"x"

        # send just fitting frame twice to make sure they are fine independently
        ws.send(should_still_fit)
        assert ws.wait() == should_still_fit
        ws.send(should_still_fit)
        assert ws.wait() == should_still_fit
        ws.send(one_too_much)

        res = ws.wait()
        assert res is None  # socket closed
        # Close code 1009 (0x03f1, message too big) followed by the reason text.
        expected_reason = (
            "Incoming frame of %d bytes is above length limit of %d bytes."
            % (len(one_too_much), TEST_MAX_FRAME_LENGTH))
        assert ws._remote_close_data == b"\x03\xf1" + six.b(expected_reason)
        eventlet.sleep(0.01)
594