1 #![cfg(feature = "full")]
2 #![warn(rust_2018_idioms)]
3 #![cfg(unix)]
4 
5 use std::io;
6 use std::task::Poll;
7 
8 use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest};
9 use tokio::net::{UnixListener, UnixStream};
10 use tokio_test::{assert_ok, assert_pending, assert_ready_ok, task};
11 
12 use futures::future::{poll_fn, try_join};
13 
14 #[tokio::test]
accept_read_write() -> std::io::Result<()>15 async fn accept_read_write() -> std::io::Result<()> {
16     let dir = tempfile::Builder::new()
17         .prefix("tokio-uds-tests")
18         .tempdir()
19         .unwrap();
20     let sock_path = dir.path().join("connect.sock");
21 
22     let listener = UnixListener::bind(&sock_path)?;
23 
24     let accept = listener.accept();
25     let connect = UnixStream::connect(&sock_path);
26     let ((mut server, _), mut client) = try_join(accept, connect).await?;
27 
28     // Write to the client. TODO: Switch to write_all.
29     let write_len = client.write(b"hello").await?;
30     assert_eq!(write_len, 5);
31     drop(client);
32     // Read from the server. TODO: Switch to read_to_end.
33     let mut buf = [0u8; 5];
34     server.read_exact(&mut buf).await?;
35     assert_eq!(&buf, b"hello");
36     let len = server.read(&mut buf).await?;
37     assert_eq!(len, 0);
38     Ok(())
39 }
40 
41 #[tokio::test]
shutdown() -> std::io::Result<()>42 async fn shutdown() -> std::io::Result<()> {
43     let dir = tempfile::Builder::new()
44         .prefix("tokio-uds-tests")
45         .tempdir()
46         .unwrap();
47     let sock_path = dir.path().join("connect.sock");
48 
49     let listener = UnixListener::bind(&sock_path)?;
50 
51     let accept = listener.accept();
52     let connect = UnixStream::connect(&sock_path);
53     let ((mut server, _), mut client) = try_join(accept, connect).await?;
54 
55     // Shut down the client
56     AsyncWriteExt::shutdown(&mut client).await?;
57     // Read from the server should return 0 to indicate the channel has been closed.
58     let mut buf = [0u8; 1];
59     let n = server.read(&mut buf).await?;
60     assert_eq!(n, 0);
61     Ok(())
62 }
63 
64 #[tokio::test]
try_read_write() -> std::io::Result<()>65 async fn try_read_write() -> std::io::Result<()> {
66     let msg = b"hello world";
67 
68     let dir = tempfile::tempdir()?;
69     let bind_path = dir.path().join("bind.sock");
70 
71     // Create listener
72     let listener = UnixListener::bind(&bind_path)?;
73 
74     // Create socket pair
75     let client = UnixStream::connect(&bind_path).await?;
76 
77     let (server, _) = listener.accept().await?;
78     let mut written = msg.to_vec();
79 
80     // Track the server receiving data
81     let mut readable = task::spawn(server.readable());
82     assert_pending!(readable.poll());
83 
84     // Write data.
85     client.writable().await?;
86     assert_eq!(msg.len(), client.try_write(msg)?);
87 
88     // The task should be notified
89     while !readable.is_woken() {
90         tokio::task::yield_now().await;
91     }
92 
93     // Fill the write buffer using non-vectored I/O
94     loop {
95         // Still ready
96         let mut writable = task::spawn(client.writable());
97         assert_ready_ok!(writable.poll());
98 
99         match client.try_write(msg) {
100             Ok(n) => written.extend(&msg[..n]),
101             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
102                 break;
103             }
104             Err(e) => panic!("error = {:?}", e),
105         }
106     }
107 
108     {
109         // Write buffer full
110         let mut writable = task::spawn(client.writable());
111         assert_pending!(writable.poll());
112 
113         // Drain the socket from the server end using non-vectored I/O
114         let mut read = vec![0; written.len()];
115         let mut i = 0;
116 
117         while i < read.len() {
118             server.readable().await?;
119 
120             match server.try_read(&mut read[i..]) {
121                 Ok(n) => i += n,
122                 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
123                 Err(e) => panic!("error = {:?}", e),
124             }
125         }
126 
127         assert_eq!(read, written);
128     }
129 
130     written.clear();
131     client.writable().await.unwrap();
132 
133     // Fill the write buffer using vectored I/O
134     let msg_bufs: Vec<_> = msg.chunks(3).map(io::IoSlice::new).collect();
135     loop {
136         // Still ready
137         let mut writable = task::spawn(client.writable());
138         assert_ready_ok!(writable.poll());
139 
140         match client.try_write_vectored(&msg_bufs) {
141             Ok(n) => written.extend(&msg[..n]),
142             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
143                 break;
144             }
145             Err(e) => panic!("error = {:?}", e),
146         }
147     }
148 
149     {
150         // Write buffer full
151         let mut writable = task::spawn(client.writable());
152         assert_pending!(writable.poll());
153 
154         // Drain the socket from the server end using vectored I/O
155         let mut read = vec![0; written.len()];
156         let mut i = 0;
157 
158         while i < read.len() {
159             server.readable().await?;
160 
161             let mut bufs: Vec<_> = read[i..]
162                 .chunks_mut(0x10000)
163                 .map(io::IoSliceMut::new)
164                 .collect();
165             match server.try_read_vectored(&mut bufs) {
166                 Ok(n) => i += n,
167                 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
168                 Err(e) => panic!("error = {:?}", e),
169             }
170         }
171 
172         assert_eq!(read, written);
173     }
174 
175     // Now, we listen for shutdown
176     drop(client);
177 
178     loop {
179         let ready = server.ready(Interest::READABLE).await?;
180 
181         if ready.is_read_closed() {
182             break;
183         } else {
184             tokio::task::yield_now().await;
185         }
186     }
187 
188     Ok(())
189 }
190 
create_pair() -> (UnixStream, UnixStream)191 async fn create_pair() -> (UnixStream, UnixStream) {
192     let dir = assert_ok!(tempfile::tempdir());
193     let bind_path = dir.path().join("bind.sock");
194 
195     let listener = assert_ok!(UnixListener::bind(&bind_path));
196 
197     let accept = listener.accept();
198     let connect = UnixStream::connect(&bind_path);
199     let ((server, _), client) = assert_ok!(try_join(accept, connect).await);
200 
201     (client, server)
202 }
203 
// Awaits `poll_read_ready` on the given stream and asserts that it
// completes with `Ok`.
macro_rules! assert_readable_by_polling {
    ($io:expr) => {
        assert_ok!(poll_fn(|ctx| $io.poll_read_ready(ctx)).await);
    };
}
209 
// Polls `poll_read_ready` on the given stream exactly once and asserts that
// the result is `Poll::Pending`.
macro_rules! assert_not_readable_by_polling {
    ($io:expr) => {
        poll_fn(|ctx| {
            let readiness = $io.poll_read_ready(ctx);
            assert_pending!(readiness);
            // Complete immediately so the stream is polled only once.
            Poll::Ready(())
        })
        .await;
    };
}
219 
// Awaits `poll_write_ready` on the given stream and asserts that it
// completes with `Ok`.
macro_rules! assert_writable_by_polling {
    ($io:expr) => {
        assert_ok!(poll_fn(|ctx| $io.poll_write_ready(ctx)).await);
    };
}
225 
// Polls `poll_write_ready` on the given stream exactly once and asserts that
// the result is `Poll::Pending`.
macro_rules! assert_not_writable_by_polling {
    ($io:expr) => {
        poll_fn(|ctx| {
            let readiness = $io.poll_write_ready(ctx);
            assert_pending!(readiness);
            // Complete immediately so the stream is polled only once.
            Poll::Ready(())
        })
        .await;
    };
}
235 
236 #[tokio::test]
poll_read_ready()237 async fn poll_read_ready() {
238     let (mut client, mut server) = create_pair().await;
239 
240     // Initial state - not readable.
241     assert_not_readable_by_polling!(server);
242 
243     // There is data in the buffer - readable.
244     assert_ok!(client.write_all(b"ping").await);
245     assert_readable_by_polling!(server);
246 
247     // Readable until calls to `poll_read` return `Poll::Pending`.
248     let mut buf = [0u8; 4];
249     assert_ok!(server.read_exact(&mut buf).await);
250     assert_readable_by_polling!(server);
251     read_until_pending(&mut server);
252     assert_not_readable_by_polling!(server);
253 
254     // Detect the client disconnect.
255     drop(client);
256     assert_readable_by_polling!(server);
257 }
258 
259 #[tokio::test]
poll_write_ready()260 async fn poll_write_ready() {
261     let (mut client, server) = create_pair().await;
262 
263     // Initial state - writable.
264     assert_writable_by_polling!(client);
265 
266     // No space to write - not writable.
267     write_until_pending(&mut client);
268     assert_not_writable_by_polling!(client);
269 
270     // Detect the server disconnect.
271     drop(server);
272     assert_writable_by_polling!(client);
273 }
274 
read_until_pending(stream: &mut UnixStream)275 fn read_until_pending(stream: &mut UnixStream) {
276     let mut buf = vec![0u8; 1024 * 1024];
277     loop {
278         match stream.try_read(&mut buf) {
279             Ok(_) => (),
280             Err(err) => {
281                 assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
282                 break;
283             }
284         }
285     }
286 }
287 
write_until_pending(stream: &mut UnixStream)288 fn write_until_pending(stream: &mut UnixStream) {
289     let buf = vec![0u8; 1024 * 1024];
290     loop {
291         match stream.try_write(&buf) {
292             Ok(_) => (),
293             Err(err) => {
294                 assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
295                 break;
296             }
297         }
298     }
299 }
300 
301 #[tokio::test]
try_read_buf() -> std::io::Result<()>302 async fn try_read_buf() -> std::io::Result<()> {
303     let msg = b"hello world";
304 
305     let dir = tempfile::tempdir()?;
306     let bind_path = dir.path().join("bind.sock");
307 
308     // Create listener
309     let listener = UnixListener::bind(&bind_path)?;
310 
311     // Create socket pair
312     let client = UnixStream::connect(&bind_path).await?;
313 
314     let (server, _) = listener.accept().await?;
315     let mut written = msg.to_vec();
316 
317     // Track the server receiving data
318     let mut readable = task::spawn(server.readable());
319     assert_pending!(readable.poll());
320 
321     // Write data.
322     client.writable().await?;
323     assert_eq!(msg.len(), client.try_write(msg)?);
324 
325     // The task should be notified
326     while !readable.is_woken() {
327         tokio::task::yield_now().await;
328     }
329 
330     // Fill the write buffer
331     loop {
332         // Still ready
333         let mut writable = task::spawn(client.writable());
334         assert_ready_ok!(writable.poll());
335 
336         match client.try_write(msg) {
337             Ok(n) => written.extend(&msg[..n]),
338             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
339                 break;
340             }
341             Err(e) => panic!("error = {:?}", e),
342         }
343     }
344 
345     {
346         // Write buffer full
347         let mut writable = task::spawn(client.writable());
348         assert_pending!(writable.poll());
349 
350         // Drain the socket from the server end
351         let mut read = Vec::with_capacity(written.len());
352         let mut i = 0;
353 
354         while i < read.capacity() {
355             server.readable().await?;
356 
357             match server.try_read_buf(&mut read) {
358                 Ok(n) => i += n,
359                 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
360                 Err(e) => panic!("error = {:?}", e),
361             }
362         }
363 
364         assert_eq!(read, written);
365     }
366 
367     // Now, we listen for shutdown
368     drop(client);
369 
370     loop {
371         let ready = server.ready(Interest::READABLE).await?;
372 
373         if ready.is_read_closed() {
374             break;
375         } else {
376             tokio::task::yield_now().await;
377         }
378     }
379 
380     Ok(())
381 }
382 
383 // https://github.com/tokio-rs/tokio/issues/3879
384 #[tokio::test]
385 #[cfg(not(target_os = "macos"))]
epollhup() -> io::Result<()>386 async fn epollhup() -> io::Result<()> {
387     let dir = tempfile::Builder::new()
388         .prefix("tokio-uds-tests")
389         .tempdir()
390         .unwrap();
391     let sock_path = dir.path().join("connect.sock");
392 
393     let listener = UnixListener::bind(&sock_path)?;
394     let connect = UnixStream::connect(&sock_path);
395     tokio::pin!(connect);
396 
397     // Poll `connect` once.
398     poll_fn(|cx| {
399         use std::future::Future;
400 
401         assert_pending!(connect.as_mut().poll(cx));
402         Poll::Ready(())
403     })
404     .await;
405 
406     drop(listener);
407 
408     let err = connect.await.unwrap_err();
409     assert_eq!(err.kind(), io::ErrorKind::ConnectionReset);
410     Ok(())
411 }
412