1import pytest
2
3import socket as stdlib_socket
4import select
5import random
6import errno
7from contextlib import suppress
8
9from ... import _core
10from ...testing import wait_all_tasks_blocked, Sequencer, assert_checkpoints
11import trio
12
13# Cross-platform tests for IO handling
14
15
def fill_socket(sock):
    """Send filler bytes on the non-blocking *sock* until a send would block.

    Data is pushed in 64 KiB chunks of ``b"x"``; the loop ends as soon as
    ``send()`` raises BlockingIOError, leaving the send buffer full.
    """
    payload = b"x" * 65536
    with suppress(BlockingIOError):
        while True:
            sock.send(payload)
22
23
def drain_socket(sock):
    """Read and discard everything currently buffered on the non-blocking *sock*.

    Stops when ``recv()`` would block (BlockingIOError), or when it returns
    ``b""`` because the peer has shut down its sending side. Without the EOF
    check, an empty read would never raise and the loop would spin forever.
    """
    try:
        while sock.recv(65536):
            pass
    except BlockingIOError:
        pass
30
31
@pytest.fixture
def socketpair():
    """Provide a connected pair of non-blocking stdlib sockets.

    Both sockets are closed during fixture teardown.
    """
    left, right = stdlib_socket.socketpair()
    left.setblocking(False)
    right.setblocking(False)
    yield left, right
    left.close()
    right.close()
40
41
def using_fileno(fn):
    """Adapt *fn* so it receives ``fileobj.fileno()`` instead of *fileobj*.

    The returned wrapper's ``__name__`` and ``__qualname__`` are both set to
    ``"<NAME on fileno>"`` (NAME being ``fn.__name__``), so it can be told
    apart from *fn* when used as a parametrize id.
    """
    label = "<{} on fileno>".format(fn.__name__)

    def fileno_wrapper(fileobj):
        fd = fileobj.fileno()
        return fn(fd)

    fileno_wrapper.__name__ = label
    fileno_wrapper.__qualname__ = label
    return fileno_wrapper
49
50
wait_readable_options = [trio.lowlevel.wait_readable]
wait_writable_options = [trio.lowlevel.wait_writable]
notify_closing_options = [trio.lowlevel.notify_closing]

# Give every option list a second entry that passes the raw fileno instead of
# the socket object. The list comprehension is evaluated before extend(), so
# only the original entries get wrapped.
for options_list in (
    wait_readable_options,
    wait_writable_options,
    notify_closing_options,
):
    options_list.extend([using_fileno(f) for f in options_list])


def _option_id(fn):
    # Parametrize ids come from the function's __name__, which using_fileno
    # rewrites to "<name on fileno>".
    return fn.__name__


# Decorators that feed in the different variants of wait_readable /
# wait_writable / notify_closing. Stacking all three on one test runs every
# combination (N**3 cases, where N is the number of variants per list).
read_socket_test = pytest.mark.parametrize(
    "wait_readable", wait_readable_options, ids=_option_id
)
write_socket_test = pytest.mark.parametrize(
    "wait_writable", wait_writable_options, ids=_option_id
)
notify_closing_test = pytest.mark.parametrize(
    "notify_closing", notify_closing_options, ids=_option_id
)
75
76
# XXX These tests are all a bit dicey because they can't distinguish between
# wait_{readable,writable} blocking the way it should, versus blocking
# momentarily and then immediately resuming.
@read_socket_test
@write_socket_test
async def test_wait_basic(socketpair, wait_readable, wait_writable):
    """Check blocking, wakeup and cancellation of wait_readable/wait_writable.

    Runs once per combination of the parametrized wait_readable and
    wait_writable variants (socket object vs. raw fileno).
    """
    a, b = socketpair

    # They start out writable()
    with assert_checkpoints():
        await wait_writable(a)

    # But readable() blocks until data arrives
    record = []

    async def block_on_read():
        try:
            with assert_checkpoints():
                await wait_readable(a)
        except _core.Cancelled:
            record.append("cancelled")
        else:
            record.append("readable")
            assert a.recv(10) == b"x"

    async with _core.open_nursery() as nursery:
        nursery.start_soon(block_on_read)
        await wait_all_tasks_blocked()
        assert record == []
        b.send(b"x")
    # The reader must have been woken by the incoming data, not cancelled.
    assert record == ["readable"]

    fill_socket(a)

    # Now writable will block, but readable won't
    with assert_checkpoints():
        await wait_readable(b)
    record = []

    async def block_on_write():
        try:
            with assert_checkpoints():
                await wait_writable(a)
        except _core.Cancelled:
            record.append("cancelled")
        else:
            record.append("writable")

    async with _core.open_nursery() as nursery:
        nursery.start_soon(block_on_write)
        await wait_all_tasks_blocked()
        assert record == []
        drain_socket(b)
    # Draining the peer must have freed buffer space and woken the writer.
    assert record == ["writable"]

    # check cancellation
    record = []
    async with _core.open_nursery() as nursery:
        nursery.start_soon(block_on_read)
        await wait_all_tasks_blocked()
        nursery.cancel_scope.cancel()
    assert record == ["cancelled"]

    fill_socket(a)
    record = []
    async with _core.open_nursery() as nursery:
        nursery.start_soon(block_on_write)
        await wait_all_tasks_blocked()
        nursery.cancel_scope.cancel()
    assert record == ["cancelled"]
145
146
@read_socket_test
async def test_double_read(socketpair, wait_readable):
    """A second concurrent wait_readable on one socket raises BusyResourceError."""
    sock, _peer = socketpair

    async with _core.open_nursery() as nursery:
        # Park one task in wait_readable; the peer never sends anything here.
        nursery.start_soon(wait_readable, sock)
        await wait_all_tasks_blocked()
        # A second reader on the same socket must be rejected.
        with pytest.raises(_core.BusyResourceError):
            await wait_readable(sock)
        # Release the parked task so the nursery can exit.
        nursery.cancel_scope.cancel()
158
159
@write_socket_test
async def test_double_write(socketpair, wait_writable):
    """A second concurrent wait_writable on one socket raises BusyResourceError."""
    sock, _peer = socketpair

    # Fill the send buffer first so the first writer actually blocks.
    fill_socket(sock)
    async with _core.open_nursery() as nursery:
        nursery.start_soon(wait_writable, sock)
        await wait_all_tasks_blocked()
        # A second writer on the same socket must be rejected.
        with pytest.raises(_core.BusyResourceError):
            await wait_writable(sock)
        # Release the parked task so the nursery can exit.
        nursery.cancel_scope.cancel()
172
173
@read_socket_test
@write_socket_test
@notify_closing_test
async def test_interrupted_by_close(
    socketpair, wait_readable, wait_writable, notify_closing
):
    """notify_closing wakes pending readers and writers with ClosedResourceError."""
    sock, _peer = socketpair

    async def expect_closed(wait):
        with pytest.raises(_core.ClosedResourceError):
            await wait(sock)

    # Fill the send buffer so the writer blocks as well as the reader.
    fill_socket(sock)

    async with _core.open_nursery() as nursery:
        nursery.start_soon(expect_closed, wait_readable)
        nursery.start_soon(expect_closed, wait_writable)
        await wait_all_tasks_blocked()
        notify_closing(sock)
196        notify_closing(a)
197
198
@read_socket_test
@write_socket_test
async def test_socket_simultaneous_read_write(socketpair, wait_readable, wait_writable):
    """A reader and a writer waiting on the same socket wake independently."""
    record = []

    async def wait_then_log(wait, sock, label):
        await wait(sock)
        record.append(label)

    a, b = socketpair
    # With a's send buffer full, both waiters on `a` start out blocked.
    fill_socket(a)
    async with _core.open_nursery() as nursery:
        nursery.start_soon(wait_then_log, wait_readable, a, "r_task")
        nursery.start_soon(wait_then_log, wait_writable, a, "w_task")
        await wait_all_tasks_blocked()
        assert record == []

        # Data arriving on `a` wakes only the reader...
        b.send(b"x")
        await wait_all_tasks_blocked()
        assert record == ["r_task"]

        # ...and freeing buffer space on the peer then wakes the writer.
        drain_socket(b)
        await wait_all_tasks_blocked()
        assert record == ["r_task", "w_task"]
225
226
@read_socket_test
@write_socket_test
async def test_socket_actual_streaming(socketpair, wait_readable, wait_writable):
    """Stream at least N bytes in each direction, using only wait_* for readiness.

    Each socket runs one sender and one receiver concurrently. The bytes
    sent from one side must equal the bytes received on the other.
    """
    a, b = socketpair

    # Use a small send buffer on one of the sockets to increase the chance of
    # getting partial writes
    a.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_SNDBUF, 10000)

    N = 1000000  # 1 megabyte
    MAX_CHUNK = 65536

    results = {}

    async def sender(sock, seed, key):
        rng = random.Random(seed)
        total = 0
        while total < N:
            print("sent", total)
            pending = bytearray(rng.randrange(MAX_CHUNK))
            while pending:
                with assert_checkpoints():
                    await wait_writable(sock)
                n_sent = sock.send(pending)
                total += n_sent
                del pending[:n_sent]
        # Half-close so the peer's receiver sees EOF and stops.
        sock.shutdown(stdlib_socket.SHUT_WR)
        results[key] = total

    async def receiver(sock, key):
        total = 0
        while True:
            print("received", total)
            with assert_checkpoints():
                await wait_readable(sock)
            n_recv = len(sock.recv(MAX_CHUNK))
            if n_recv == 0:
                break
            total += n_recv
        results[key] = total

    async with _core.open_nursery() as nursery:
        for sock, seed, key in ((a, 0, "send_a"), (b, 1, "send_b")):
            nursery.start_soon(sender, sock, seed, key)
        for sock, key in ((a, "recv_a"), (b, "recv_b")):
            nursery.start_soon(receiver, sock, key)

    assert results["send_a"] == results["recv_b"]
    assert results["send_b"] == results["recv_a"]
276
277
async def test_notify_closing_on_invalid_object():
    """notify_closing on an invalid fd/handle may raise OSError, and nothing else.

    It should either be a no-op (generally on Unix, where we don't know
    which fds are valid), or an OSError (on Windows, where we currently only
    support sockets, so we have to do some validation to figure out whether
    it's a socket or a regular handle).
    """
    # Both acceptable outcomes pass; any other exception type propagates and
    # fails the test. Tracking the outcome in flags adds nothing, because
    # exactly one of the two outcomes always happens.
    with suppress(OSError):
        trio.lowlevel.notify_closing(-1)
292
293
async def test_wait_on_invalid_object():
    """wait_readable/wait_writable raise OSError when given an fd that is closed."""
    # We definitely want to raise an error everywhere if you pass in an
    # invalid fd to wait_*
    waiters = (trio.lowlevel.wait_readable, trio.lowlevel.wait_writable)
    for wait_fn in waiters:
        sock = stdlib_socket.socket()
        with sock:
            stale_fd = sock.fileno()
        # We just closed the socket and don't do anything else in between, so
        # we can be confident that the fileno hasn't been reassigned.
        with pytest.raises(OSError):
            await wait_fn(stale_fd)
304
305
async def test_io_manager_statistics():
    """Check that io_statistics counts tasks blocked in wait_readable/wait_writable.

    The test expects one extra reader at all times (the "call_soon_task"
    mentioned below), so every expected reader count includes it.
    """

    def check(*, expected_readers, expected_writers):
        statistics = _core.current_statistics()
        print(statistics)
        iostats = statistics.io_statistics
        if iostats.backend in ["epoll", "windows"]:
            assert iostats.tasks_waiting_read == expected_readers
            assert iostats.tasks_waiting_write == expected_writers
        else:
            # For kqueue this test only checks the combined tasks_waiting count.
            assert iostats.backend == "kqueue"
            assert iostats.tasks_waiting == expected_readers + expected_writers

    a1, b1 = stdlib_socket.socketpair()
    a2, b2 = stdlib_socket.socketpair()
    a3, b3 = stdlib_socket.socketpair()
    with a1, b1, a2, b2, a3, b3:
        # Configure the sockets inside the `with`, so that they are still
        # closed if setblocking() raises.
        for sock in [a1, b1, a2, b2, a3, b3]:
            sock.setblocking(False)

        # let the call_soon_task settle down
        await wait_all_tasks_blocked()

        # 1 for call_soon_task
        check(expected_readers=1, expected_writers=0)

        # We want:
        # - one socket with a writer blocked
        # - two sockets with a reader blocked
        # - a socket with both blocked
        fill_socket(a1)
        fill_socket(a3)
        async with _core.open_nursery() as nursery:
            nursery.start_soon(trio.lowlevel.wait_writable, a1)
            nursery.start_soon(trio.lowlevel.wait_readable, a2)
            nursery.start_soon(trio.lowlevel.wait_readable, b2)
            nursery.start_soon(trio.lowlevel.wait_writable, a3)
            nursery.start_soon(trio.lowlevel.wait_readable, a3)

            await wait_all_tasks_blocked()

            # +1 for call_soon_task
            check(expected_readers=3 + 1, expected_writers=2)

            nursery.cancel_scope.cancel()

        # 1 for call_soon_task
        check(expected_readers=1, expected_writers=0)
352
353
async def test_can_survive_unnotified_close():
    """Close sockets under pending wait_* calls without notify_closing first.

    Each scenario must end with the waiters returning, raising OSError, or
    being cancelled cleanly by the nursery's cancel scope.
    """
    # An "unnotified" close is when the user closes an fd/socket/handle
    # directly, without calling notify_closing first. This should never happen
    # -- users should call notify_closing before closing things. But, just in
    # case they don't, we would still like to avoid exploding.
    #
    # Acceptable behaviors:
    # - wait_* never return, but can be cancelled cleanly
    # - wait_* exit cleanly
    # - wait_* raise an OSError
    #
    # Not acceptable:
    # - getting stuck in an uncancellable state
    # - TrioInternalError blowing up the whole run
    #
    # This test exercises some tricky "unnotified close" scenarios, to make
    # sure we get the "acceptable" behaviors.

    # Runs the given wait_* call, treating an OSError as an acceptable outcome.
    async def allow_OSError(async_func, *args):
        with suppress(OSError):
            await async_func(*args)

    # Simplest case: a single waiter, and the only handle to the socket is
    # closed underneath it.
    with stdlib_socket.socket() as s:
        async with trio.open_nursery() as nursery:
            nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, s)
            await wait_all_tasks_blocked()
            s.close()
            await wait_all_tasks_blocked()
            nursery.cancel_scope.cancel()

    # We hit different paths on Windows depending on whether we close the last
    # handle to the object (which produces a LOCAL_CLOSE notification and
    # wakes up wait_readable), or only close one of the handles (which leaves
    # wait_readable pending until cancelled).
    with stdlib_socket.socket() as s, s.dup() as s2:  # noqa: F841
        async with trio.open_nursery() as nursery:
            nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, s)
            await wait_all_tasks_blocked()
            s.close()
            await wait_all_tasks_blocked()
            nursery.cancel_scope.cancel()

    # A more elaborate case, with two tasks waiting. On windows and epoll,
    # the two tasks get muxed together onto a single underlying wait
    # operation. So when they're cancelled, there's a brief moment where one
    # of the tasks is cancelled but the other isn't, so we try to re-issue the
    # underlying wait operation. But here, the handle we were going to use to
    # do that has been pulled out from under our feet... so test that we can
    # survive this.
    a, b = stdlib_socket.socketpair()
    with a, b, a.dup() as a2:  # noqa: F841
        a.setblocking(False)
        b.setblocking(False)
        # Fill a's send buffer so the wait_writable below blocks too.
        fill_socket(a)
        async with trio.open_nursery() as nursery:
            nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, a)
            nursery.start_soon(allow_OSError, trio.lowlevel.wait_writable, a)
            await wait_all_tasks_blocked()
            # No await between close() and cancel(): the cancellation is
            # delivered immediately after the close, with no chance for the
            # scheduler to run anything in between.
            a.close()
            nursery.cancel_scope.cancel()

    # A similar case, but now the single-task-wakeup happens due to I/O
    # arriving, not a cancellation, so the operation gets re-issued from
    # handle_io context rather than abort context.
    a, b = stdlib_socket.socketpair()
    with a, b, a.dup() as a2:  # noqa: F841
        print("a={}, b={}, a2={}".format(a.fileno(), b.fileno(), a2.fileno()))
        a.setblocking(False)
        b.setblocking(False)
        # Fill a's send buffer so the wait_writable below stays blocked.
        fill_socket(a)
        e = trio.Event()

        # We want to wait for the kernel to process the wakeup on 'a', if any.
        # But depending on the platform, we might not get a wakeup on 'a'. So
        # we put one task to sleep waiting on 'a', and we put a second task to
        # sleep waiting on 'a2', with the idea that the 'a2' notification will
        # definitely arrive, and when it does then we can assume that whatever
        # notification was going to arrive for 'a' has also arrived.
        async def wait_readable_a2_then_set():
            await trio.lowlevel.wait_readable(a2)
            e.set()

        async with trio.open_nursery() as nursery:
            nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, a)
            nursery.start_soon(allow_OSError, trio.lowlevel.wait_writable, a)
            nursery.start_soon(wait_readable_a2_then_set)
            await wait_all_tasks_blocked()
            a.close()
            # Data from b arrives on the socket that a2 (a dup of a) still
            # refers to, making it readable.
            b.send(b"x")
            # Make sure that the wakeup has been received and everything has
            # settled before cancelling the wait_writable.
            await e.wait()
            await wait_all_tasks_blocked()
            nursery.cancel_scope.cancel()
448