1# HTTP websocket server functional tests
2
3import asyncio
4
5import pytest
6
7import aiohttp
8from aiohttp import web
9from aiohttp.http import WSMsgType
10
11
async def test_websocket_can_prepare(loop, aiohttp_client) -> None:
    """A plain (non-upgrade) GET to a websocket handler is answered with 426."""

    async def handler(request):
        # Only a proper upgrade request may proceed; everything else is refused.
        if web.WebSocketResponse().can_prepare(request):
            return web.Response()
        raise web.HTTPUpgradeRequired()

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    response = await client.get("/")
    assert response.status == 426
26
27
async def test_websocket_json(loop, aiohttp_client) -> None:
    """The server decodes a TEXT frame via ``msg.json()`` and echoes one field."""

    async def handler(request):
        ws = web.WebSocketResponse()
        if not ws.can_prepare(request):
            # HTTP exceptions must be raised from a handler, not returned.
            raise web.HTTPUpgradeRequired()

        await ws.prepare(request)
        msg = await ws.receive()

        msg_json = msg.json()
        answer = msg_json["test"]
        await ws.send_str(answer)

        await ws.close()
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/")
    expected_value = "value"
    payload = '{"test": "%s"}' % expected_value
    await ws.send_str(payload)

    resp = await ws.receive()
    assert resp.data == expected_value

    # Release the client side explicitly instead of relying on fixture teardown.
    await ws.close()
55
56
async def test_websocket_json_invalid_message(loop, aiohttp_client) -> None:
    """``receive_json()`` raises ``ValueError`` for a TEXT frame that is not JSON."""

    async def handler(request):
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        try:
            await ws.receive_json()
        except ValueError:
            # Report the decode failure back so the client can assert on it.
            await ws.send_str("ValueError was raised")
        else:
            raise Exception("No Exception")
        finally:
            await ws.close()
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/")
    payload = "NOT A VALID JSON STRING"
    await ws.send_str(payload)

    data = await ws.receive_str()
    assert "ValueError was raised" in data

    # Release the client side explicitly instead of relying on fixture teardown.
    await ws.close()
81
82
async def test_websocket_send_json(loop, aiohttp_client) -> None:
    """``send_json``/``receive_json`` round-trip a dict in both directions."""

    async def handler(request):
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        # Echo the decoded object straight back, re-encoded as JSON.
        data = await ws.receive_json()
        await ws.send_json(data)

        await ws.close()
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/")
    expected_value = "value"
    await ws.send_json({"test": expected_value})

    data = await ws.receive_json()
    assert data["test"] == expected_value

    # Release the client side explicitly instead of relying on fixture teardown.
    await ws.close()
104
105
async def test_websocket_receive_json(loop, aiohttp_client) -> None:
    """Server-side ``receive_json()`` decodes a raw JSON TEXT frame."""

    async def handler(request):
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        data = await ws.receive_json()
        answer = data["test"]
        await ws.send_str(answer)

        await ws.close()
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/")
    expected_value = "value"
    payload = '{"test": "%s"}' % expected_value
    await ws.send_str(payload)

    resp = await ws.receive()
    assert resp.data == expected_value

    # Release the client side explicitly instead of relying on fixture teardown.
    await ws.close()
129
130
async def test_send_recv_text(loop, aiohttp_client) -> None:
    """Round-trip a TEXT frame, then observe a clean 1000 close on the client."""
    handler_done = loop.create_future()

    async def handler(request):
        server_ws = web.WebSocketResponse()
        await server_ws.prepare(request)
        question = await server_ws.receive_str()
        await server_ws.send_str(question + "/answer")
        await server_ws.close()
        handler_done.set_result(1)
        return server_ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    client_ws = await client.ws_connect("/")
    await client_ws.send_str("ask")

    reply = await client_ws.receive()
    assert reply.type == aiohttp.WSMsgType.TEXT
    assert "ask/answer" == reply.data

    # The server's close() arrives as a CLOSE message carrying code 1000.
    close_msg = await client_ws.receive()
    assert close_msg.type == aiohttp.WSMsgType.CLOSE
    assert close_msg.data == 1000
    assert close_msg.extra == ""

    assert client_ws.closed
    assert client_ws.close_code == 1000

    await handler_done
163
164
async def test_send_recv_bytes(loop, aiohttp_client) -> None:
    """Round-trip a BINARY frame, then observe a clean 1000 close on the client."""
    handler_done = loop.create_future()

    async def handler(request):
        server_ws = web.WebSocketResponse()
        await server_ws.prepare(request)

        question = await server_ws.receive_bytes()
        await server_ws.send_bytes(question + b"/answer")
        await server_ws.close()
        handler_done.set_result(1)
        return server_ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    client_ws = await client.ws_connect("/")
    await client_ws.send_bytes(b"ask")

    reply = await client_ws.receive()
    assert reply.type == aiohttp.WSMsgType.BINARY
    assert b"ask/answer" == reply.data

    # The server's close() arrives as a CLOSE message carrying code 1000.
    close_msg = await client_ws.receive()
    assert close_msg.type == aiohttp.WSMsgType.CLOSE
    assert close_msg.data == 1000
    assert close_msg.extra == ""

    assert client_ws.closed
    assert client_ws.close_code == 1000

    await handler_done
198
199
async def test_send_recv_json(loop, aiohttp_client) -> None:
    """Server answers a JSON request with ``send_json`` and then closes with 1000."""
    closed = loop.create_future()

    async def handler(request):
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        data = await ws.receive_json()
        await ws.send_json({"response": data["request"]})
        await ws.close()
        closed.set_result(1)
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/")

    await ws.send_str('{"request": "test"}')
    msg = await ws.receive()
    # Check the frame type before decoding so a wrong frame fails on this
    # assertion rather than inside msg.json().
    assert msg.type == aiohttp.WSMsgType.TEXT
    data = msg.json()
    assert data["response"] == "test"

    msg = await ws.receive()
    assert msg.type == aiohttp.WSMsgType.CLOSE
    assert msg.data == 1000
    assert msg.extra == ""

    await ws.close()

    await closed
232
233
async def test_close_timeout(loop, aiohttp_client) -> None:
    """Server-side ``close()`` gives up after ``timeout`` if the peer never replies.

    The handler's ``WebSocketResponse`` uses ``timeout=0.1``.  The client reads
    the server's CLOSE frame but never answers it, so the handler's ``close()``
    must still return (with close code 1006 and a stored
    ``asyncio.TimeoutError``) within roughly twice that timeout.
    """
    aborted = loop.create_future()  # resolved once the handler's close() returned
    elapsed = 1e10  # something big; overwritten by the handler with close()'s duration

    async def handler(request):
        nonlocal elapsed
        ws = web.WebSocketResponse(timeout=0.1)
        await ws.prepare(request)
        assert "request" == (await ws.receive_str())
        await ws.send_str("reply")
        begin = ws._loop.time()
        assert await ws.close()
        elapsed = ws._loop.time() - begin
        # 1006 (abnormal closure) plus a stored TimeoutError: no CLOSE reply
        # arrived within the close timeout.
        assert ws.close_code == 1006
        assert isinstance(ws.exception(), asyncio.TimeoutError)
        aborted.set_result(1)
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/")
    await ws.send_str("request")
    assert "reply" == (await ws.receive_str())

    # The server closes here.  The client then reads the CLOSE frame straight
    # from the private reader (presumably bypassing the automatic close reply
    # that ws.receive() would send), so the server never gets an answer and
    # its close() has to time out instead of hanging indefinitely.
    await asyncio.sleep(0.08)
    msg = await ws._reader.read()
    assert msg.type == WSMsgType.CLOSE

    await asyncio.sleep(0.08)
    assert await aborted

    assert elapsed < 0.25, "close() should have returned before " "at most 2x timeout."

    await ws.close()
273
274
async def test_concurrent_close(loop, aiohttp_client) -> None:
    """Close the server websocket from outside its handler.

    Both peers use ``autoclose=False``.  Once the client is connected, the test
    body calls ``srv_ws.close(code=1007)`` directly, presumably while the
    handler is already parked in ``ws.receive()``.  The handler must then
    observe CLOSING twice and finally CLOSED; the client sees the CLOSE frame
    and then CLOSED.
    """

    # Server-side response published by the handler so the test body can close
    # it concurrently with the handler's own receive() calls.
    srv_ws = None

    async def handler(request):
        nonlocal srv_ws
        ws = srv_ws = web.WebSocketResponse(autoclose=False, protocols=("foo", "bar"))
        await ws.prepare(request)

        # NOTE(review): the two CLOSING results reflect aiohttp's internal close
        # state machine while a concurrent close() is in progress — confirm
        # against WebSocketResponse.receive()/close() if this test changes.
        msg = await ws.receive()
        assert msg.type == WSMsgType.CLOSING

        msg = await ws.receive()
        assert msg.type == WSMsgType.CLOSING

        # Yield once to the event loop, presumably so the concurrent close()
        # can make progress before the final receive().
        await asyncio.sleep(0)

        msg = await ws.receive()
        assert msg.type == WSMsgType.CLOSED

        return ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/", autoclose=False, protocols=("eggs", "bar"))

    # srv_ws is assigned before prepare() in the handler, so it is set by the
    # time the client's handshake has completed.
    await srv_ws.close(code=1007)

    # With autoclose=False the client surfaces the CLOSE frame to the caller.
    msg = await ws.receive()
    assert msg.type == WSMsgType.CLOSE

    await asyncio.sleep(0)
    msg = await ws.receive()
    assert msg.type == WSMsgType.CLOSED
311
312
async def test_auto_pong_with_closing_by_peer(loop, aiohttp_client) -> None:
    """The server auto-answers a client ping, then sees the client's close frame."""
    handler_done = loop.create_future()

    async def handler(request):
        server_ws = web.WebSocketResponse()
        await server_ws.prepare(request)
        # The first delivered message (presumably the "ask" text) is discarded.
        await server_ws.receive()

        close_msg = await server_ws.receive()
        assert close_msg.type == WSMsgType.CLOSE
        assert close_msg.data == 1000
        assert close_msg.extra == "exit message"
        handler_done.set_result(None)
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    client_ws = await client.ws_connect("/", autoclose=False, autoping=False)
    await client_ws.ping()
    await client_ws.send_str("ask")

    pong = await client_ws.receive()
    assert pong.type == WSMsgType.PONG
    await client_ws.close(code=1000, message="exit message")
    await handler_done
341
342
async def test_ping(loop, aiohttp_client) -> None:
    """A server-initiated ping reaches a client that has autoping disabled."""
    handler_done = loop.create_future()

    async def handler(request):
        server_ws = web.WebSocketResponse()
        await server_ws.prepare(request)

        await server_ws.ping("data")
        # Wait for the client's follow-up traffic before finishing.
        await server_ws.receive()
        handler_done.set_result(None)
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    client_ws = await client.ws_connect("/", autoping=False)

    ping_msg = await client_ws.receive()
    assert ping_msg.type == WSMsgType.PING
    assert ping_msg.data == b"data"
    await client_ws.pong()
    await client_ws.close()
    await handler_done
368
369
async def test_client_ping(loop, aiohttp_client) -> None:
    """A client-initiated ping is answered by the server with a matching pong.

    Formerly named ``aiohttp_client_ping``; without the ``test_`` prefix pytest
    never collected it.
    """
    closed = loop.create_future()

    async def handler(request):
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        await ws.receive()
        closed.set_result(None)
        return ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    # autoping=False so the server's pong is surfaced to the test.
    ws = await client.ws_connect("/", autoping=False)

    await ws.ping("data")
    msg = await ws.receive()
    assert msg.type == WSMsgType.PONG
    assert msg.data == b"data"
    await ws.pong()
    await ws.close()
    # Make sure the server handler actually ran to completion.
    await closed


# Backward-compatible alias for the old (never collected) name.
aiohttp_client_ping = test_client_ping
394
395
async def test_pong(loop, aiohttp_client) -> None:
    """With autoping disabled, the server answers a ping manually via ``pong()``."""
    handler_done = loop.create_future()

    async def handler(request):
        server_ws = web.WebSocketResponse(autoping=False)
        await server_ws.prepare(request)

        ping_msg = await server_ws.receive()
        assert ping_msg.type == WSMsgType.PING
        await server_ws.pong("data")

        close_msg = await server_ws.receive()
        assert close_msg.type == WSMsgType.CLOSE
        assert close_msg.data == 1000
        assert close_msg.extra == "exit message"
        handler_done.set_result(None)
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    client_ws = await client.ws_connect("/", autoping=False)

    await client_ws.ping("data")
    pong = await client_ws.receive()
    assert pong.type == WSMsgType.PONG
    assert pong.data == b"data"

    await client_ws.close(code=1000, message="exit message")

    await handler_done
429
430
async def test_change_status(loop, aiohttp_client) -> None:
    """``prepare()`` replaces a previously set status with 101."""
    handler_done = loop.create_future()

    async def handler(request):
        server_ws = web.WebSocketResponse()
        server_ws.set_status(200)
        assert 200 == server_ws.status
        await server_ws.prepare(request)
        # The handshake switches the status to 101 regardless of set_status().
        assert 101 == server_ws.status
        await server_ws.close()
        handler_done.set_result(None)
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    client_ws = await client.ws_connect("/", autoping=False)

    await client_ws.close()
    await handler_done
    # A second close() on an already closed client websocket must not fail.
    await client_ws.close()
454
455
async def test_handle_protocol(loop, aiohttp_client) -> None:
    """The server negotiates the subprotocol that both sides offer ("bar")."""
    handler_done = loop.create_future()

    async def handler(request):
        server_ws = web.WebSocketResponse(protocols=("foo", "bar"))
        await server_ws.prepare(request)
        await server_ws.close()
        assert "bar" == server_ws.ws_protocol
        handler_done.set_result(None)
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    client_ws = await client.ws_connect("/", protocols=("eggs", "bar"))

    await client_ws.close()
    await handler_done
476
477
async def test_server_close_handshake(loop, aiohttp_client) -> None:
    """A server-initiated close is delivered to a client with autoclose disabled."""
    handler_done = loop.create_future()

    async def handler(request):
        server_ws = web.WebSocketResponse(protocols=("foo", "bar"))
        await server_ws.prepare(request)
        await server_ws.close()
        handler_done.set_result(None)
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    client_ws = await client.ws_connect(
        "/", autoclose=False, protocols=("eggs", "bar")
    )

    close_msg = await client_ws.receive()
    assert close_msg.type == WSMsgType.CLOSE
    await client_ws.close()
    await handler_done
499
500
async def test_client_close_handshake(loop, aiohttp_client) -> None:
    """A client-initiated close (code 1007) completes the close handshake.

    Formerly named ``aiohttp_client_close_handshake``; without the ``test_``
    prefix pytest never collected it.
    """
    closed = loop.create_future()

    async def handler(request):
        ws = web.WebSocketResponse(autoclose=False, protocols=("foo", "bar"))
        await ws.prepare(request)

        msg = await ws.receive()
        assert msg.type == WSMsgType.CLOSE
        # autoclose=False: the CLOSE frame is surfaced but not answered yet.
        assert not ws.closed
        await ws.close()
        assert ws.closed
        # The recorded close code is the one the client sent.
        assert ws.close_code == 1007

        msg = await ws.receive()
        assert msg.type == WSMsgType.CLOSED

        closed.set_result(None)
        return ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect("/", autoclose=False, protocols=("eggs", "bar"))

    await ws.close(code=1007)
    msg = await ws.receive()
    assert msg.type == WSMsgType.CLOSED
    await closed


# Backward-compatible alias for the old (never collected) name.
aiohttp_client_close_handshake = test_client_close_handshake
532
533
async def test_server_close_handshake_server_eats_client_messages(
    loop, aiohttp_client
) -> None:
    """Frames the client sends after the server's CLOSE must not break shutdown.

    The server closes immediately.  The client (autoclose and autoping both
    disabled) keeps sending a text frame, a binary frame and a ping before it
    finally closes, and the handler must still finish.
    """
    closed = loop.create_future()

    async def handler(request):
        ws = web.WebSocketResponse(protocols=("foo", "bar"))
        await ws.prepare(request)
        await ws.close()
        closed.set_result(None)
        return ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    ws = await client.ws_connect(
        "/", autoclose=False, autoping=False, protocols=("eggs", "bar")
    )

    msg = await ws.receive()
    assert msg.type == WSMsgType.CLOSE

    # Late traffic after the server has already started closing.
    await ws.send_str("text")
    await ws.send_bytes(b"bytes")
    await ws.ping()

    await ws.close()
    await closed
561
562
async def test_receive_timeout(loop, aiohttp_client) -> None:
    """``receive_timeout`` on the response makes an idle ``receive()`` time out."""
    raised = False

    async def handler(request):
        nonlocal raised
        server_ws = web.WebSocketResponse(receive_timeout=0.1)
        await server_ws.prepare(request)

        # The client sends nothing, so this wait is expected to time out.
        try:
            await server_ws.receive()
        except asyncio.TimeoutError:
            raised = True

        await server_ws.close()
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    client_ws = await client.ws_connect("/")
    await client_ws.receive()
    await client_ws.close()
    assert raised
587
588
async def test_custom_receive_timeout(loop, aiohttp_client) -> None:
    """A per-call ``receive(timeout)`` applies even with no default receive_timeout."""
    raised = False

    async def handler(request):
        nonlocal raised
        server_ws = web.WebSocketResponse(receive_timeout=None)
        await server_ws.prepare(request)

        # Only the explicit 0.1 s argument bounds this wait.
        try:
            await server_ws.receive(0.1)
        except asyncio.TimeoutError:
            raised = True

        await server_ws.close()
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    client_ws = await client.ws_connect("/")
    await client_ws.receive()
    await client_ws.close()
    assert raised
613
614
async def test_heartbeat(loop, aiohttp_client) -> None:
    """With ``heartbeat`` set, the server sends periodic pings to the client."""

    async def handler(request):
        ws = web.WebSocketResponse(heartbeat=0.05)
        await ws.prepare(request)
        await ws.receive()
        await ws.close()
        return ws

    app = web.Application()
    app.router.add_get("/", handler)

    client = await aiohttp_client(app)
    # autoping=False so the heartbeat ping is surfaced to the test.
    ws = await client.ws_connect("/", autoping=False)
    msg = await ws.receive()

    # Use the canonical upper-case member, as the rest of this module does.
    assert msg.type == aiohttp.WSMsgType.PING

    await ws.close()
633
634
async def test_heartbeat_no_pong(loop, aiohttp_client) -> None:
    """The heartbeat ping reaches a client that never answers it with a pong."""

    async def handler(request):
        ws = web.WebSocketResponse(heartbeat=0.05)
        await ws.prepare(request)

        await ws.receive()
        return ws

    app = web.Application()
    app.router.add_get("/", handler)

    client = await aiohttp_client(app)
    # autoping=False: the client does not reply to the ping automatically.
    ws = await client.ws_connect("/", autoping=False)
    msg = await ws.receive()
    # Use the canonical upper-case member, as the rest of this module does.
    assert msg.type == aiohttp.WSMsgType.PING
    await ws.close()
651
652
async def test_server_ws_async_for(loop, aiohttp_server) -> None:
    """``async for`` over a server websocket yields each TEXT frame until close."""
    handler_done = loop.create_future()

    async def handler(request):
        server_ws = web.WebSocketResponse()
        await server_ws.prepare(request)
        async for incoming in server_ws:
            assert incoming.type == aiohttp.WSMsgType.TEXT
            await server_ws.send_str(incoming.data + "/answer")
        await server_ws.close()
        handler_done.set_result(1)
        return server_ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    server = await aiohttp_server(app)

    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(server.make_url("/")) as client_ws:
            for question in ["q1", "q2", "q3"]:
                await client_ws.send_str(question)
                reply = await client_ws.receive()
                assert reply.type == aiohttp.WSMsgType.TEXT
                assert question + "/answer" == reply.data

            await client_ws.close()
            await handler_done
683
684
async def test_closed_async_for(loop, aiohttp_client) -> None:
    """Calling ``close()`` inside an ``async for`` loop ends the iteration."""
    handler_done = loop.create_future()

    async def handler(request):
        server_ws = web.WebSocketResponse()
        await server_ws.prepare(request)

        received = []
        async for incoming in server_ws:
            received.append(incoming)
            if incoming.data == "stop":
                await server_ws.send_str("stopping")
                await server_ws.close()

        # Exactly one frame (the "stop" text) was yielded before the loop ended.
        assert len(received) == 1
        (only,) = received
        assert only.type == WSMsgType.TEXT
        assert only.data == "stop"

        handler_done.set_result(None)
        return server_ws

    app = web.Application()
    app.router.add_get("/", handler)
    client = await aiohttp_client(app)

    client_ws = await client.ws_connect("/")
    await client_ws.send_str("stop")
    reply = await client_ws.receive()
    assert reply.type == WSMsgType.TEXT
    assert reply.data == "stopping"

    await client_ws.close()
    await handler_done
719
720
async def test_websocket_disable_keepalive(loop, aiohttp_client) -> None:
    """``prepare()`` turns off HTTP keep-alive on the underlying protocol.

    The same handler serves a plain GET (answered with "OK") and a websocket
    upgrade; after the upgrade the protocol's keep-alive flag and handle must
    be cleared.
    """

    async def handler(request):
        ws = web.WebSocketResponse()
        if not ws.can_prepare(request):
            return web.Response(text="OK")
        assert request.protocol._keepalive
        await ws.prepare(request)
        assert not request.protocol._keepalive
        assert not request.protocol._keepalive_handle

        await ws.send_str("OK")
        await ws.close()
        return ws

    app = web.Application()
    app.router.add_route("GET", "/", handler)
    client = await aiohttp_client(app)

    resp = await client.get("/")
    txt = await resp.text()
    assert txt == "OK"

    ws = await client.ws_connect("/")
    data = await ws.receive_str()
    assert data == "OK"

    # Release the client side explicitly instead of relying on fixture teardown.
    await ws.close()
746
747
async def test_bug3380(loop, aiohttp_client) -> None:
    """A failed websocket handshake must not break later plain requests.

    Regression test for aiohttp issue #3380: after ``ws_connect`` fails with
    ``WSServerHandshakeError`` (the handler answers 401 instead of 101), the
    same client must still be able to issue an ordinary request and read its
    JSON body.
    """

    async def handle_null(request):
        return web.json_response({"err": None})

    async def ws_handler(request):
        # Deliberately refuse the upgrade with a non-101 status.
        return web.Response(status=401)

    app = web.Application()
    app.router.add_route("GET", "/ws", ws_handler)
    app.router.add_route("GET", "/api/null", handle_null)

    client = await aiohttp_client(app)

    resp = await client.get("/api/null")
    assert (await resp.json()) == {"err": None}
    resp.close()

    with pytest.raises(aiohttp.WSServerHandshakeError):
        await client.ws_connect("/ws")

    # Pass the documented ClientTimeout object rather than a bare number.
    resp = await client.get("/api/null", timeout=aiohttp.ClientTimeout(total=1))
    assert (await resp.json()) == {"err": None}
    resp.close()
771