1# XX this should get broken up, like testing.py did
2
3import tempfile
4
5import pytest
6
7from .._core.tests.tutil import can_bind_ipv6
8from .. import sleep
9from .. import _core
10from .._highlevel_generic import aclose_forcefully
11from ..testing import *
12from ..testing._check_streams import _assert_raises
13from ..testing._memory_streams import _UnboundedByteQueue
14from .. import socket as tsocket
15from .._highlevel_socket import SocketListener
16
17
async def test_wait_all_tasks_blocked():
    """Check wait_all_tasks_blocked() against a busy task, and its cancellation."""
    record = []

    async def busy_bee():
        # Runs through ten checkpoints before recording that it is done.
        for _ in range(10):
            await _core.checkpoint()
        record.append("busy bee exhausted")

    async def waiting_for_bee_to_leave():
        await wait_all_tasks_blocked()
        record.append("quiet at last!")

    async with _core.open_nursery() as nursery:
        nursery.start_soon(busy_bee)
        nursery.start_soon(waiting_for_bee_to_leave)
        nursery.start_soon(waiting_for_bee_to_leave)

    # The waiters must not wake up until the busy task has finished. Check
    # this before `record` is reset for the next scenario.
    assert record == [
        "busy bee exhausted",
        "quiet at last!",
        "quiet at last!",
    ]

    # check cancellation
    record = []

    async def cancelled_while_waiting():
        try:
            await wait_all_tasks_blocked()
        except _core.Cancelled:
            record.append("ok")

    async with _core.open_nursery() as nursery:
        nursery.start_soon(cancelled_while_waiting)
        nursery.cancel_scope.cancel()
    assert record == ["ok"]
48
49
async def test_wait_all_tasks_blocked_with_timeouts(mock_clock):
    """Exercise wait_all_tasks_blocked() around a task sleeping on the mock clock."""
    events = []

    async def sleeper():
        events.append("tt start")
        await sleep(5)
        events.append("tt finished")

    async with _core.open_nursery() as nursery:
        nursery.start_soon(sleeper)

        # The sleeper has started but its sleep has not yet expired.
        await wait_all_tasks_blocked()
        assert events == ["tt start"]

        # Jump the clock past the sleep deadline and let the sleeper finish.
        mock_clock.jump(10)
        await wait_all_tasks_blocked()
        assert events == ["tt start", "tt finished"]
65
66
async def test_wait_all_tasks_blocked_with_cushion():
    """Check the wake-up order of waiters that pass different cushion values."""
    log = []

    async def blink():
        log.append("blink start")
        # Three short sleeps in a row.
        for _ in range(3):
            await sleep(0.01)
        log.append("blink end")

    async def wait_no_cushion():
        await wait_all_tasks_blocked()
        log.append("wait_no_cushion end")

    async def wait_small_cushion():
        await wait_all_tasks_blocked(0.02)
        log.append("wait_small_cushion end")

    async def wait_big_cushion():
        await wait_all_tasks_blocked(0.03)
        log.append("wait_big_cushion end")

    # Start order is kept as-is: one blinker, then the waiters.
    children = [
        blink,
        wait_no_cushion,
        wait_small_cushion,
        wait_small_cushion,
        wait_big_cushion,
    ]
    async with _core.open_nursery() as nursery:
        for child in children:
            nursery.start_soon(child)

    expected = [
        "blink start",
        "wait_no_cushion end",
        "blink end",
        "wait_small_cushion end",
        "wait_small_cushion end",
        "wait_big_cushion end",
    ]
    assert log == expected
104
105
106################################################################
107
108
async def test_assert_checkpoints(recwarn):
    """assert_checkpoints() passes only when the body is a full checkpoint."""
    # A real checkpoint is accepted.
    with assert_checkpoints():
        await _core.checkpoint()

    # No checkpoint at all is rejected.
    with pytest.raises(AssertionError), assert_checkpoints():
        1 + 1

    # partial yield cases
    # A schedule point without a cancel point (or the other way around) is
    # not a checkpoint.
    half_checkpoints = (
        _core.checkpoint_if_cancelled,
        _core.cancel_shielded_checkpoint,
    )
    for half_checkpoint in half_checkpoints:
        print(half_checkpoint)
        with pytest.raises(AssertionError), assert_checkpoints():
            await half_checkpoint()

    # The two halves together do count as a checkpoint.
    with assert_checkpoints():
        await _core.checkpoint_if_cancelled()
        await _core.cancel_shielded_checkpoint()
133
134
async def test_assert_no_checkpoints(recwarn):
    """assert_no_checkpoints() rejects full and partial checkpoints alike."""
    # Plain synchronous code is accepted.
    with assert_no_checkpoints():
        1 + 1

    # A real checkpoint is rejected.
    with pytest.raises(AssertionError), assert_no_checkpoints():
        await _core.checkpoint()

    # partial yield cases
    # A schedule point without a cancel point (or the other way around)
    # satisfies neither assert_checkpoints nor assert_no_checkpoints.
    half_checkpoints = (
        _core.checkpoint_if_cancelled,
        _core.cancel_shielded_checkpoint,
    )
    for half_checkpoint in half_checkpoints:
        print(half_checkpoint)
        with pytest.raises(AssertionError), assert_no_checkpoints():
            await half_checkpoint()

    # The two halves together are rejected as well.
    with pytest.raises(AssertionError), assert_no_checkpoints():
        await _core.checkpoint_if_cancelled()
        await _core.cancel_shielded_checkpoint()
160
161
162################################################################
163
164
async def test_Sequencer():
    """Sequencer runs its blocks in position order and rejects reused positions."""
    seen = []

    def note(entry):
        print(entry)
        seen.append(entry)

    async def f1(seq):
        for position in (1, 3, 4):
            async with seq(position):
                note(("f1", position))

    async def f2(seq):
        for position in (0, 2):
            async with seq(position):
                note(("f2", position))

    seq = Sequencer()
    async with _core.open_nursery() as nursery:
        nursery.start_soon(f1, seq)
        nursery.start_soon(f2, seq)
        async with seq(5):
            await wait_all_tasks_blocked()
        expected = [("f2", 0), ("f1", 1), ("f2", 2), ("f1", 3), ("f1", 4)]
        assert seen == expected

    # Catches us if we try to re-use a sequence point:
    fresh_seq = Sequencer()
    async with fresh_seq(0):
        pass
    with pytest.raises(RuntimeError):
        async with fresh_seq(0):
            pass  # pragma: no cover
201
202
async def test_Sequencer_cancel():
    """Cancelling one waiter on a Sequencer makes every waiter fail."""
    # Killing a blocked task makes everything blow up
    failures = []
    seq = Sequencer()

    async def child(i):
        with _core.CancelScope() as scope:
            # Only child 1 is cancelled; the other child is left alone.
            if i == 1:
                scope.cancel()
            try:
                async with seq(i):
                    pass  # pragma: no cover
            except RuntimeError:
                failures.append("seq({}) RuntimeError".format(i))

    async with _core.open_nursery() as nursery:
        for index in (1, 2):
            nursery.start_soon(child, index)
        async with seq(0):
            pass  # pragma: no cover

    assert failures == ["seq(1) RuntimeError", "seq(2) RuntimeError"]

    # Late arrivals also get errors
    with pytest.raises(RuntimeError):
        async with seq(3):
            pass  # pragma: no cover
230
231
232################################################################
233
234
async def test__assert_raises():
    """Cover the three outcomes of the private _assert_raises helper."""
    # Nothing raised -> AssertionError.
    with pytest.raises(AssertionError), _assert_raises(RuntimeError):
        1 + 1

    # A different exception type passes straight through.
    with pytest.raises(TypeError), _assert_raises(RuntimeError):
        "foo" + 1

    # The expected exception is swallowed.
    with _assert_raises(RuntimeError):
        raise RuntimeError
246
247
# This is a private implementation detail, but it's complex enough to be worth
# testing directly
async def test__UnboundeByteQueue():
    """Exercise _UnboundedByteQueue: put/get, blocking gets, conflicts and close."""
    ubq = _UnboundedByteQueue()

    # Synchronous gets, with and without a size limit
    ubq.put(b"123")
    ubq.put(b"456")
    assert ubq.get_nowait(1) == b"1"
    assert ubq.get_nowait(10) == b"23456"
    ubq.put(b"789")
    assert ubq.get_nowait() == b"789"

    # Empty queue -> WouldBlock
    with pytest.raises(_core.WouldBlock):
        ubq.get_nowait(10)
    with pytest.raises(_core.WouldBlock):
        ubq.get_nowait()

    # Only bytes-like data is accepted
    with pytest.raises(TypeError):
        ubq.put("string")

    # Async gets must be checkpoints even when data is already available
    ubq.put(b"abc")
    with assert_checkpoints():
        assert await ubq.get(10) == b"abc"
    ubq.put(b"def")
    ubq.put(b"ghi")
    with assert_checkpoints():
        assert await ubq.get(1) == b"d"
    with assert_checkpoints():
        assert await ubq.get() == b"efghi"

    async def putter(data):
        await wait_all_tasks_blocked()
        ubq.put(data)

    async def getter(expect):
        with assert_checkpoints():
            assert await ubq.get() == expect

    # A get that starts before the data arrives
    async with _core.open_nursery() as nursery:
        nursery.start_soon(getter, b"xyz")
        nursery.start_soon(putter, b"xyz")

    # Two gets at the same time -> BusyResourceError
    with pytest.raises(_core.BusyResourceError):
        async with _core.open_nursery() as nursery:
            nursery.start_soon(getter, b"asdf")
            nursery.start_soon(getter, b"asdf")

    # Closing

    ubq.close()
    with pytest.raises(_core.ClosedResourceError):
        ubq.put(b"---")

    assert ubq.get_nowait(10) == b""
    assert ubq.get_nowait() == b""
    assert await ubq.get(10) == b""
    assert await ubq.get() == b""

    # close is idempotent
    ubq.close()

    # close wakes up blocked getters
    ubq2 = _UnboundedByteQueue()

    async def blocked_getter():
        # This has to read from ubq2, the queue that closer() closes. The
        # shared getter() above reads from ubq, which is already closed, so
        # it would return immediately without ever waiting for the close.
        with assert_checkpoints():
            assert await ubq2.get() == b""

    async def closer():
        await wait_all_tasks_blocked()
        ubq2.close()

    async with _core.open_nursery() as nursery:
        nursery.start_soon(blocked_getter)
        nursery.start_soon(closer)
320
321
async def test_MemorySendStream():
    """Exercise MemorySendStream: sending, draining, conflicts, closing and hooks."""
    stream = MemorySendStream()

    async def checked_send_all(payload):
        with assert_checkpoints():
            await stream.send_all(payload)

    await checked_send_all(b"123")
    assert stream.get_data_nowait(1) == b"1"
    assert stream.get_data_nowait() == b"23"

    with assert_checkpoints():
        await stream.wait_send_all_might_not_block()

    # Nothing buffered -> WouldBlock, with or without a size limit.
    for nowait_args in ((), (10,)):
        with pytest.raises(_core.WouldBlock):
            stream.get_data_nowait(*nowait_args)

    await checked_send_all(b"456")
    with assert_checkpoints():
        assert await stream.get_data() == b"456"

    # Call send_all twice at once; one should get BusyResourceError and one
    # should succeed. But we can't let the error propagate, because it might
    # cause the other to be cancelled before it can finish doing its thing,
    # and we don't know which one will get the error.
    busy_errors = []

    async def send_and_collect_busy():
        try:
            await checked_send_all(b"xxx")
        except _core.BusyResourceError as exc:
            busy_errors.append(exc)

    async with _core.open_nursery() as nursery:
        for _ in range(2):
            nursery.start_soon(send_and_collect_busy)

    assert len(busy_errors) == 1

    with assert_checkpoints():
        await stream.aclose()

    # Data sent before the close can still be read, followed by empty bytes.
    assert await stream.get_data() == b"xxx"
    assert await stream.get_data() == b""
    with pytest.raises(_core.ClosedResourceError):
        await checked_send_all(b"---")

    # hooks

    # All hooks default to None.
    for unset_hook in (
        stream.send_all_hook,
        stream.wait_send_all_might_not_block_hook,
        stream.close_hook,
    ):
        assert unset_hook is None

    calls = []

    async def on_send_all():
        # hook runs after send_all does its work (can pull data out)
        assert hooked.get_data_nowait() == b"abc"
        calls.append("send_all_hook")

    async def on_wait_send_all_might_not_block():
        calls.append("wait_send_all_might_not_block_hook")

    def on_close():
        calls.append("close_hook")

    hooked = MemorySendStream(
        on_send_all, on_wait_send_all_might_not_block, on_close
    )

    assert hooked.send_all_hook is on_send_all
    assert (
        hooked.wait_send_all_might_not_block_hook
        is on_wait_send_all_might_not_block
    )
    assert hooked.close_hook is on_close

    await hooked.send_all(b"abc")
    await hooked.wait_send_all_might_not_block()
    await aclose_forcefully(hooked)
    hooked.close()

    # The close hook runs once per close call.
    expected_calls = [
        "send_all_hook",
        "wait_send_all_might_not_block_hook",
        "close_hook",
        "close_hook",
    ]
    assert calls == expected_calls
410
411
async def test_MemoryReceiveStream():
    """Exercise MemoryReceiveStream: receiving, conflicts, EOF, hooks and close."""
    stream = MemoryReceiveStream()

    async def checked_receive(max_bytes):
        with assert_checkpoints():
            chunk = await stream.receive_some(max_bytes)
        return chunk

    stream.put_data(b"abc")
    assert await checked_receive(1) == b"a"
    assert await checked_receive(10) == b"bc"
    stream.put_data(b"abc")
    assert await checked_receive(None) == b"abc"

    # Two receives at the same time -> BusyResourceError
    with pytest.raises(_core.BusyResourceError):
        async with _core.open_nursery() as nursery:
            for _ in range(2):
                nursery.start_soon(checked_receive, 10)

    assert stream.receive_some_hook is None

    # put_eof may be called more than once.
    stream.put_data(b"def")
    stream.put_eof()
    stream.put_eof()

    # Pending data first, then empty bytes on every later call.
    for expected in (b"def", b"", b""):
        assert await checked_receive(10) == expected

    with pytest.raises(_core.ClosedResourceError):
        stream.put_data(b"---")

    async def feed_on_receive():
        hooked.put_data(b"xxx")

    closes = []

    def on_close():
        closes.append("closed")

    hooked = MemoryReceiveStream(feed_on_receive, on_close)
    assert hooked.receive_some_hook is feed_on_receive
    assert hooked.close_hook is on_close

    # Each receive gets the data its hook call just added.
    hooked.put_data(b"yyy")
    for expected in (b"yyyxxx", b"xxx", b"xxx"):
        assert await hooked.receive_some(10) == expected

    # With the hook removed, only explicitly queued data comes back.
    hooked.put_data(b"zzz")
    hooked.receive_some_hook = None
    assert await hooked.receive_some(10) == b"zzz"

    hooked.put_data(b"lost on close")
    with assert_checkpoints():
        await hooked.aclose()
    assert closes == ["closed"]

    with pytest.raises(_core.ClosedResourceError):
        await hooked.receive_some(10)
471
472
async def test_MemoryRecvStream_closing():
    """Closing a MemoryReceiveStream makes receive_some and put_data fail."""
    mrs = MemoryReceiveStream()
    # close with no pending data
    mrs.close()
    # receive_some is expected to raise here, so there is no return value to
    # compare against.
    with pytest.raises(_core.ClosedResourceError):
        await mrs.receive_some(10)
    # repeated closes ok
    mrs.close()
    # put_data now fails
    with pytest.raises(_core.ClosedResourceError):
        mrs.put_data(b"123")

    mrs2 = MemoryReceiveStream()
    # close with pending data
    mrs2.put_data(b"xyz")
    mrs2.close()
    with pytest.raises(_core.ClosedResourceError):
        await mrs2.receive_some(10)
491
492
async def test_memory_stream_pump():
    """memory_stream_pump moves data, and the close, from send to receive stream."""
    sender = MemorySendStream()
    receiver = MemoryReceiveStream()

    # no-op if no data present
    memory_stream_pump(sender, receiver)

    await sender.send_all(b"123")
    memory_stream_pump(sender, receiver)
    assert await receiver.receive_some(10) == b"123"

    # Pump one byte at a time; the result is truthy while data was moved.
    await sender.send_all(b"456")
    assert memory_stream_pump(sender, receiver, max_bytes=1)
    assert await receiver.receive_some(10) == b"4"
    for _ in range(2):
        assert memory_stream_pump(sender, receiver, max_bytes=1)
    assert not memory_stream_pump(sender, receiver, max_bytes=1)
    assert await receiver.receive_some(10) == b"56"

    # Pumping after the sender is closed gives empty bytes on the receiver.
    sender.close()
    memory_stream_pump(sender, receiver)
    assert await receiver.receive_some(10) == b""
515
516
async def test_memory_stream_one_way_pair():
    """Check the hook wiring and data flow of memory_stream_one_way_pair()."""
    send_stream, recv_stream = memory_stream_one_way_pair()
    assert send_stream.send_all_hook is not None
    assert send_stream.wait_send_all_might_not_block_hook is None
    assert send_stream.close_hook is not None
    assert recv_stream.receive_some_hook is None
    await send_stream.send_all(b"123")
    assert await recv_stream.receive_some(10) == b"123"

    async def receiver(expected):
        # Reads from whichever receive stream is currently bound.
        received = await recv_stream.receive_some(10)
        assert received == expected

    # This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook
    async with _core.open_nursery() as nursery:
        nursery.start_soon(receiver, b"abc")
        await wait_all_tasks_blocked()
        await send_stream.send_all(b"abc")

    # And this fails if we don't pump from close_hook
    async with _core.open_nursery() as nursery:
        nursery.start_soon(receiver, b"")
        await wait_all_tasks_blocked()
        await send_stream.aclose()

    # Same again, but with the synchronous close().
    send_stream, recv_stream = memory_stream_one_way_pair()

    async with _core.open_nursery() as nursery:
        nursery.start_soon(receiver, b"")
        await wait_all_tasks_blocked()
        send_stream.close()

    send_stream, recv_stream = memory_stream_one_way_pair()

    # Remove the send hook for now, keeping it so it can be put back later.
    saved_hook = send_stream.send_all_hook
    send_stream.send_all_hook = None
    await send_stream.send_all(b"456")

    async def cancel_after_idle(nursery):
        await wait_all_tasks_blocked()
        nursery.cancel_scope.cancel()

    async def check_for_cancel():
        with pytest.raises(_core.Cancelled):
            # This should block forever... or until cancelled. Even though we
            # sent some data on the send stream.
            await recv_stream.receive_some(10)

    async with _core.open_nursery() as nursery:
        nursery.start_soon(cancel_after_idle, nursery)
        nursery.start_soon(check_for_cancel)

    # With the hook back, the next send delivers both chunks.
    send_stream.send_all_hook = saved_hook
    await send_stream.send_all(b"789")
    assert await recv_stream.receive_some(10) == b"456789"
571
572
async def test_memory_stream_pair():
    """Data sent on one end of memory_stream_pair() arrives on the other."""
    left, right = memory_stream_pair()
    await left.send_all(b"123")
    await right.send_all(b"abc")
    assert await right.receive_some(10) == b"123"
    assert await left.receive_some(10) == b"abc"

    # After send_eof on one end, the other end receives empty bytes.
    await left.send_eof()
    assert await right.receive_some(10) == b""

    async def sender():
        await wait_all_tasks_blocked()
        await right.send_all(b"xyz")

    async def receiver():
        received = await left.receive_some(10)
        assert received == b"xyz"

    # Start the receiver first, then the sender.
    async with _core.open_nursery() as nursery:
        for child in (receiver, sender):
            nursery.start_soon(child)
593
594
async def test_memory_streams_with_generic_tests():
    """Run the generic stream checkers against the in-memory streams."""

    async def make_one_way_pair():
        pair = memory_stream_one_way_pair()
        return pair

    async def make_half_closeable_pair():
        pair = memory_stream_pair()
        return pair

    # The second argument is None for both checkers.
    await check_one_way_stream(make_one_way_pair, None)
    await check_half_closeable_stream(make_half_closeable_pair, None)
605
606
async def test_lockstep_streams_with_generic_tests():
    """Run the generic stream checkers against the lockstep streams."""

    async def make_one_way_pair():
        pair = lockstep_stream_one_way_pair()
        return pair

    async def make_two_way_pair():
        pair = lockstep_stream_pair()
        return pair

    # Each maker is passed as both arguments of its checker.
    await check_one_way_stream(make_one_way_pair, make_one_way_pair)
    await check_two_way_stream(make_two_way_pair, make_two_way_pair)
617
618
async def test_open_stream_to_socket_listener():
    """open_stream_to_socket_listener() connects to listeners of each address kind."""

    async def check(listener):
        # Connect to the listener, accept the connection, and check that one
        # byte gets from the client to the server.
        async with listener:
            client_stream = await open_stream_to_socket_listener(listener)
            async with client_stream:
                server_stream = await listener.accept()
                async with server_stream:
                    await client_stream.send_all(b"x")
                    # This has to be asserted: written as a bare expression,
                    # the comparison result is discarded and nothing is
                    # actually checked.
                    assert await server_stream.receive_some(1) == b"x"

    # Listener bound to localhost
    sock = tsocket.socket()
    await sock.bind(("127.0.0.1", 0))
    sock.listen(10)
    await check(SocketListener(sock))

    # Listener bound to IPv4 wildcard (needs special handling)
    sock = tsocket.socket()
    await sock.bind(("0.0.0.0", 0))
    sock.listen(10)
    await check(SocketListener(sock))

    if can_bind_ipv6:
        # Listener bound to IPv6 wildcard (needs special handling)
        sock = tsocket.socket(family=tsocket.AF_INET6)
        await sock.bind(("::", 0))
        sock.listen(10)
        await check(SocketListener(sock))

    if hasattr(tsocket, "AF_UNIX"):
        # Listener bound to Unix-domain socket
        sock = tsocket.socket(family=tsocket.AF_UNIX)
        # can't use pytest's tmpdir; if we try then macOS says "OSError:
        # AF_UNIX path too long"
        with tempfile.TemporaryDirectory() as tmpdir:
            path = "{}/sock".format(tmpdir)
            await sock.bind(path)
            sock.listen(10)
            await check(SocketListener(sock))
658