1 #![warn(rust_2018_idioms)]
2 #![cfg(feature = "full")]
3 
4 use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest};
5 use tokio::net::{TcpListener, TcpStream};
6 use tokio::try_join;
7 use tokio_test::task;
8 use tokio_test::{assert_ok, assert_pending, assert_ready_ok};
9 
10 use std::io;
11 use std::task::Poll;
12 use std::time::Duration;
13 
14 use futures::future::poll_fn;
15 
16 #[tokio::test]
set_linger()17 async fn set_linger() {
18     let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
19 
20     let stream = TcpStream::connect(listener.local_addr().unwrap())
21         .await
22         .unwrap();
23 
24     assert_ok!(stream.set_linger(Some(Duration::from_secs(1))));
25     assert_eq!(stream.linger().unwrap().unwrap().as_secs(), 1);
26 
27     assert_ok!(stream.set_linger(None));
28     assert!(stream.linger().unwrap().is_none());
29 }
30 
31 #[tokio::test]
try_read_write()32 async fn try_read_write() {
33     const DATA: &[u8] = b"this is some data to write to the socket";
34 
35     // Create listener
36     let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
37 
38     // Create socket pair
39     let client = TcpStream::connect(listener.local_addr().unwrap())
40         .await
41         .unwrap();
42     let (server, _) = listener.accept().await.unwrap();
43     let mut written = DATA.to_vec();
44 
45     // Track the server receiving data
46     let mut readable = task::spawn(server.readable());
47     assert_pending!(readable.poll());
48 
49     // Write data.
50     client.writable().await.unwrap();
51     assert_eq!(DATA.len(), client.try_write(DATA).unwrap());
52 
53     // The task should be notified
54     while !readable.is_woken() {
55         tokio::task::yield_now().await;
56     }
57 
58     // Fill the write buffer using non-vectored I/O
59     loop {
60         // Still ready
61         let mut writable = task::spawn(client.writable());
62         assert_ready_ok!(writable.poll());
63 
64         match client.try_write(DATA) {
65             Ok(n) => written.extend(&DATA[..n]),
66             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
67                 break;
68             }
69             Err(e) => panic!("error = {:?}", e),
70         }
71     }
72 
73     {
74         // Write buffer full
75         let mut writable = task::spawn(client.writable());
76         assert_pending!(writable.poll());
77 
78         // Drain the socket from the server end using non-vectored I/O
79         let mut read = vec![0; written.len()];
80         let mut i = 0;
81 
82         while i < read.len() {
83             server.readable().await.unwrap();
84 
85             match server.try_read(&mut read[i..]) {
86                 Ok(n) => i += n,
87                 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
88                 Err(e) => panic!("error = {:?}", e),
89             }
90         }
91 
92         assert_eq!(read, written);
93     }
94 
95     written.clear();
96     client.writable().await.unwrap();
97 
98     // Fill the write buffer using vectored I/O
99     let data_bufs: Vec<_> = DATA.chunks(10).map(io::IoSlice::new).collect();
100     loop {
101         // Still ready
102         let mut writable = task::spawn(client.writable());
103         assert_ready_ok!(writable.poll());
104 
105         match client.try_write_vectored(&data_bufs) {
106             Ok(n) => written.extend(&DATA[..n]),
107             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
108                 break;
109             }
110             Err(e) => panic!("error = {:?}", e),
111         }
112     }
113 
114     {
115         // Write buffer full
116         let mut writable = task::spawn(client.writable());
117         assert_pending!(writable.poll());
118 
119         // Drain the socket from the server end using vectored I/O
120         let mut read = vec![0; written.len()];
121         let mut i = 0;
122 
123         while i < read.len() {
124             server.readable().await.unwrap();
125 
126             let mut bufs: Vec<_> = read[i..]
127                 .chunks_mut(0x10000)
128                 .map(io::IoSliceMut::new)
129                 .collect();
130             match server.try_read_vectored(&mut bufs) {
131                 Ok(n) => i += n,
132                 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
133                 Err(e) => panic!("error = {:?}", e),
134             }
135         }
136 
137         assert_eq!(read, written);
138     }
139 
140     // Now, we listen for shutdown
141     drop(client);
142 
143     loop {
144         let ready = server.ready(Interest::READABLE).await.unwrap();
145 
146         if ready.is_read_closed() {
147             return;
148         } else {
149             tokio::task::yield_now().await;
150         }
151     }
152 }
153 
/// Checks that a read buffer declared after the await point does not end up
/// inside the generated future.
///
/// The future is only constructed and measured, never polled, so no
/// connection is attempted. Its size must stay below 1000 bytes even though
/// the loop body uses an `N`-byte (4096) stack buffer.
#[test]
fn buffer_not_included_in_future() {
    use std::mem::size_of_val;

    const N: usize = 4096;

    let future = async {
        let stream = TcpStream::connect("127.0.0.1:8080").await.unwrap();

        loop {
            stream.readable().await.unwrap();

            // Declared after the await so it is not live across a suspension.
            let mut scratch = [0u8; N];
            if stream.try_read(&mut scratch[..]).unwrap() == 0 {
                break;
            }
        }
    };

    let future_size = size_of_val(&future);
    assert!(future_size < 1000);
}
178 
/// Awaits `poll_read_ready` on the given stream and asserts that it
/// resolves to `Ok`.
macro_rules! assert_readable_by_polling {
    ($io:expr) => {{
        let readiness = poll_fn(|cx| $io.poll_read_ready(cx)).await;
        assert_ok!(readiness);
    }};
}
184 
/// Polls `poll_read_ready` on the given stream exactly once and asserts that
/// the result is `Poll::Pending`.
macro_rules! assert_not_readable_by_polling {
    ($io:expr) => {{
        let single_poll = poll_fn(|cx| {
            let readiness = $io.poll_read_ready(cx);
            assert_pending!(readiness);
            // Complete immediately so the check never actually waits.
            Poll::Ready(())
        });
        single_poll.await;
    }};
}
194 
/// Awaits `poll_write_ready` on the given stream and asserts that it
/// resolves to `Ok`.
macro_rules! assert_writable_by_polling {
    ($io:expr) => {{
        let readiness = poll_fn(|cx| $io.poll_write_ready(cx)).await;
        assert_ok!(readiness);
    }};
}
200 
/// Polls `poll_write_ready` on the given stream exactly once and asserts that
/// the result is `Poll::Pending`.
macro_rules! assert_not_writable_by_polling {
    ($io:expr) => {{
        let single_poll = poll_fn(|cx| {
            let readiness = $io.poll_write_ready(cx);
            assert_pending!(readiness);
            // Complete immediately so the check never actually waits.
            Poll::Ready(())
        });
        single_poll.await;
    }};
}
210 
211 #[tokio::test]
poll_read_ready()212 async fn poll_read_ready() {
213     let (mut client, mut server) = create_pair().await;
214 
215     // Initial state - not readable.
216     assert_not_readable_by_polling!(server);
217 
218     // There is data in the buffer - readable.
219     assert_ok!(client.write_all(b"ping").await);
220     assert_readable_by_polling!(server);
221 
222     // Readable until calls to `poll_read` return `Poll::Pending`.
223     let mut buf = [0u8; 4];
224     assert_ok!(server.read_exact(&mut buf).await);
225     assert_readable_by_polling!(server);
226     read_until_pending(&mut server);
227     assert_not_readable_by_polling!(server);
228 
229     // Detect the client disconnect.
230     drop(client);
231     assert_readable_by_polling!(server);
232 }
233 
234 #[tokio::test]
poll_write_ready()235 async fn poll_write_ready() {
236     let (mut client, server) = create_pair().await;
237 
238     // Initial state - writable.
239     assert_writable_by_polling!(client);
240 
241     // No space to write - not writable.
242     write_until_pending(&mut client);
243     assert_not_writable_by_polling!(client);
244 
245     // Detect the server disconnect.
246     drop(server);
247     assert_writable_by_polling!(client);
248 }
249 
create_pair() -> (TcpStream, TcpStream)250 async fn create_pair() -> (TcpStream, TcpStream) {
251     let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await);
252     let addr = assert_ok!(listener.local_addr());
253     let (client, (server, _)) = assert_ok!(try_join!(TcpStream::connect(&addr), listener.accept()));
254     (client, server)
255 }
256 
read_until_pending(stream: &mut TcpStream)257 fn read_until_pending(stream: &mut TcpStream) {
258     let mut buf = vec![0u8; 1024 * 1024];
259     loop {
260         match stream.try_read(&mut buf) {
261             Ok(_) => (),
262             Err(err) => {
263                 assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
264                 break;
265             }
266         }
267     }
268 }
269 
write_until_pending(stream: &mut TcpStream)270 fn write_until_pending(stream: &mut TcpStream) {
271     let buf = vec![0u8; 1024 * 1024];
272     loop {
273         match stream.try_write(&buf) {
274             Ok(_) => (),
275             Err(err) => {
276                 assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
277                 break;
278             }
279         }
280     }
281 }
282 
283 #[tokio::test]
try_read_buf()284 async fn try_read_buf() {
285     const DATA: &[u8] = b"this is some data to write to the socket";
286 
287     // Create listener
288     let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
289 
290     // Create socket pair
291     let client = TcpStream::connect(listener.local_addr().unwrap())
292         .await
293         .unwrap();
294     let (server, _) = listener.accept().await.unwrap();
295     let mut written = DATA.to_vec();
296 
297     // Track the server receiving data
298     let mut readable = task::spawn(server.readable());
299     assert_pending!(readable.poll());
300 
301     // Write data.
302     client.writable().await.unwrap();
303     assert_eq!(DATA.len(), client.try_write(DATA).unwrap());
304 
305     // The task should be notified
306     while !readable.is_woken() {
307         tokio::task::yield_now().await;
308     }
309 
310     // Fill the write buffer
311     loop {
312         // Still ready
313         let mut writable = task::spawn(client.writable());
314         assert_ready_ok!(writable.poll());
315 
316         match client.try_write(DATA) {
317             Ok(n) => written.extend(&DATA[..n]),
318             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
319                 break;
320             }
321             Err(e) => panic!("error = {:?}", e),
322         }
323     }
324 
325     {
326         // Write buffer full
327         let mut writable = task::spawn(client.writable());
328         assert_pending!(writable.poll());
329 
330         // Drain the socket from the server end
331         let mut read = Vec::with_capacity(written.len());
332         let mut i = 0;
333 
334         while i < read.capacity() {
335             server.readable().await.unwrap();
336 
337             match server.try_read_buf(&mut read) {
338                 Ok(n) => i += n,
339                 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
340                 Err(e) => panic!("error = {:?}", e),
341             }
342         }
343 
344         assert_eq!(read, written);
345     }
346 
347     // Now, we listen for shutdown
348     drop(client);
349 
350     loop {
351         let ready = server.ready(Interest::READABLE).await.unwrap();
352 
353         if ready.is_read_closed() {
354             return;
355         } else {
356             tokio::task::yield_now().await;
357         }
358     }
359 }
360