1 #![cfg(feature = "full")]
2 #![warn(rust_2018_idioms)]
3 #![cfg(unix)]
4 
5 use std::io;
6 use std::task::Poll;
7 
8 use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest};
9 use tokio::net::{UnixListener, UnixStream};
10 use tokio_test::{assert_ok, assert_pending, assert_ready_ok, task};
11 
12 use futures::future::{poll_fn, try_join};
13 
14 #[tokio::test]
accept_read_write() -> std::io::Result<()>15 async fn accept_read_write() -> std::io::Result<()> {
16     let dir = tempfile::Builder::new()
17         .prefix("tokio-uds-tests")
18         .tempdir()
19         .unwrap();
20     let sock_path = dir.path().join("connect.sock");
21 
22     let listener = UnixListener::bind(&sock_path)?;
23 
24     let accept = listener.accept();
25     let connect = UnixStream::connect(&sock_path);
26     let ((mut server, _), mut client) = try_join(accept, connect).await?;
27 
28     // Write to the client. TODO: Switch to write_all.
29     let write_len = client.write(b"hello").await?;
30     assert_eq!(write_len, 5);
31     drop(client);
32     // Read from the server. TODO: Switch to read_to_end.
33     let mut buf = [0u8; 5];
34     server.read_exact(&mut buf).await?;
35     assert_eq!(&buf, b"hello");
36     let len = server.read(&mut buf).await?;
37     assert_eq!(len, 0);
38     Ok(())
39 }
40 
41 #[tokio::test]
shutdown() -> std::io::Result<()>42 async fn shutdown() -> std::io::Result<()> {
43     let dir = tempfile::Builder::new()
44         .prefix("tokio-uds-tests")
45         .tempdir()
46         .unwrap();
47     let sock_path = dir.path().join("connect.sock");
48 
49     let listener = UnixListener::bind(&sock_path)?;
50 
51     let accept = listener.accept();
52     let connect = UnixStream::connect(&sock_path);
53     let ((mut server, _), mut client) = try_join(accept, connect).await?;
54 
55     // Shut down the client
56     AsyncWriteExt::shutdown(&mut client).await?;
57     // Read from the server should return 0 to indicate the channel has been closed.
58     let mut buf = [0u8; 1];
59     let n = server.read(&mut buf).await?;
60     assert_eq!(n, 0);
61     Ok(())
62 }
63 
64 #[tokio::test]
try_read_write() -> std::io::Result<()>65 async fn try_read_write() -> std::io::Result<()> {
66     let msg = b"hello world";
67 
68     let dir = tempfile::tempdir()?;
69     let bind_path = dir.path().join("bind.sock");
70 
71     // Create listener
72     let listener = UnixListener::bind(&bind_path)?;
73 
74     // Create socket pair
75     let client = UnixStream::connect(&bind_path).await?;
76 
77     let (server, _) = listener.accept().await?;
78     let mut written = msg.to_vec();
79 
80     // Track the server receiving data
81     let mut readable = task::spawn(server.readable());
82     assert_pending!(readable.poll());
83 
84     // Write data.
85     client.writable().await?;
86     assert_eq!(msg.len(), client.try_write(msg)?);
87 
88     // The task should be notified
89     while !readable.is_woken() {
90         tokio::task::yield_now().await;
91     }
92 
93     // Fill the write buffer
94     loop {
95         // Still ready
96         let mut writable = task::spawn(client.writable());
97         assert_ready_ok!(writable.poll());
98 
99         match client.try_write(msg) {
100             Ok(n) => written.extend(&msg[..n]),
101             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
102                 break;
103             }
104             Err(e) => panic!("error = {:?}", e),
105         }
106     }
107 
108     {
109         // Write buffer full
110         let mut writable = task::spawn(client.writable());
111         assert_pending!(writable.poll());
112 
113         // Drain the socket from the server end
114         let mut read = vec![0; written.len()];
115         let mut i = 0;
116 
117         while i < read.len() {
118             server.readable().await?;
119 
120             match server.try_read(&mut read[i..]) {
121                 Ok(n) => i += n,
122                 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
123                 Err(e) => panic!("error = {:?}", e),
124             }
125         }
126 
127         assert_eq!(read, written);
128     }
129 
130     // Now, we listen for shutdown
131     drop(client);
132 
133     loop {
134         let ready = server.ready(Interest::READABLE).await?;
135 
136         if ready.is_read_closed() {
137             break;
138         } else {
139             tokio::task::yield_now().await;
140         }
141     }
142 
143     Ok(())
144 }
145 
create_pair() -> (UnixStream, UnixStream)146 async fn create_pair() -> (UnixStream, UnixStream) {
147     let dir = assert_ok!(tempfile::tempdir());
148     let bind_path = dir.path().join("bind.sock");
149 
150     let listener = assert_ok!(UnixListener::bind(&bind_path));
151 
152     let accept = listener.accept();
153     let connect = UnixStream::connect(&bind_path);
154     let ((server, _), client) = assert_ok!(try_join(accept, connect).await);
155 
156     (client, server)
157 }
158 
/// Awaits `poll_read_ready` on the given stream and asserts the result is `Ok`.
macro_rules! assert_readable_by_polling {
    ($sock:expr) => {{
        let outcome = poll_fn(|cx| $sock.poll_read_ready(cx)).await;
        assert_ok!(outcome);
    }};
}
164 
/// Polls `poll_read_ready` on the given stream once and asserts that it is
/// pending. The wrapping future completes right away so the test never parks.
macro_rules! assert_not_readable_by_polling {
    ($sock:expr) => {{
        poll_fn(|cx| {
            let state = $sock.poll_read_ready(cx);
            assert_pending!(state);
            Poll::Ready(())
        })
        .await;
    }};
}
174 
/// Awaits `poll_write_ready` on the given stream and asserts the result is `Ok`.
macro_rules! assert_writable_by_polling {
    ($sock:expr) => {{
        let outcome = poll_fn(|cx| $sock.poll_write_ready(cx)).await;
        assert_ok!(outcome);
    }};
}
180 
/// Polls `poll_write_ready` on the given stream once and asserts that it is
/// pending. The wrapping future completes right away so the test never parks.
macro_rules! assert_not_writable_by_polling {
    ($sock:expr) => {{
        poll_fn(|cx| {
            let state = $sock.poll_write_ready(cx);
            assert_pending!(state);
            Poll::Ready(())
        })
        .await;
    }};
}
190 
191 #[tokio::test]
poll_read_ready()192 async fn poll_read_ready() {
193     let (mut client, mut server) = create_pair().await;
194 
195     // Initial state - not readable.
196     assert_not_readable_by_polling!(server);
197 
198     // There is data in the buffer - readable.
199     assert_ok!(client.write_all(b"ping").await);
200     assert_readable_by_polling!(server);
201 
202     // Readable until calls to `poll_read` return `Poll::Pending`.
203     let mut buf = [0u8; 4];
204     assert_ok!(server.read_exact(&mut buf).await);
205     assert_readable_by_polling!(server);
206     read_until_pending(&mut server);
207     assert_not_readable_by_polling!(server);
208 
209     // Detect the client disconnect.
210     drop(client);
211     assert_readable_by_polling!(server);
212 }
213 
214 #[tokio::test]
poll_write_ready()215 async fn poll_write_ready() {
216     let (mut client, server) = create_pair().await;
217 
218     // Initial state - writable.
219     assert_writable_by_polling!(client);
220 
221     // No space to write - not writable.
222     write_until_pending(&mut client);
223     assert_not_writable_by_polling!(client);
224 
225     // Detect the server disconnect.
226     drop(server);
227     assert_writable_by_polling!(client);
228 }
229 
read_until_pending(stream: &mut UnixStream)230 fn read_until_pending(stream: &mut UnixStream) {
231     let mut buf = vec![0u8; 1024 * 1024];
232     loop {
233         match stream.try_read(&mut buf) {
234             Ok(_) => (),
235             Err(err) => {
236                 assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
237                 break;
238             }
239         }
240     }
241 }
242 
write_until_pending(stream: &mut UnixStream)243 fn write_until_pending(stream: &mut UnixStream) {
244     let buf = vec![0u8; 1024 * 1024];
245     loop {
246         match stream.try_write(&buf) {
247             Ok(_) => (),
248             Err(err) => {
249                 assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
250                 break;
251             }
252         }
253     }
254 }
255 
256 #[tokio::test]
try_read_buf() -> std::io::Result<()>257 async fn try_read_buf() -> std::io::Result<()> {
258     let msg = b"hello world";
259 
260     let dir = tempfile::tempdir()?;
261     let bind_path = dir.path().join("bind.sock");
262 
263     // Create listener
264     let listener = UnixListener::bind(&bind_path)?;
265 
266     // Create socket pair
267     let client = UnixStream::connect(&bind_path).await?;
268 
269     let (server, _) = listener.accept().await?;
270     let mut written = msg.to_vec();
271 
272     // Track the server receiving data
273     let mut readable = task::spawn(server.readable());
274     assert_pending!(readable.poll());
275 
276     // Write data.
277     client.writable().await?;
278     assert_eq!(msg.len(), client.try_write(msg)?);
279 
280     // The task should be notified
281     while !readable.is_woken() {
282         tokio::task::yield_now().await;
283     }
284 
285     // Fill the write buffer
286     loop {
287         // Still ready
288         let mut writable = task::spawn(client.writable());
289         assert_ready_ok!(writable.poll());
290 
291         match client.try_write(msg) {
292             Ok(n) => written.extend(&msg[..n]),
293             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
294                 break;
295             }
296             Err(e) => panic!("error = {:?}", e),
297         }
298     }
299 
300     {
301         // Write buffer full
302         let mut writable = task::spawn(client.writable());
303         assert_pending!(writable.poll());
304 
305         // Drain the socket from the server end
306         let mut read = Vec::with_capacity(written.len());
307         let mut i = 0;
308 
309         while i < read.capacity() {
310             server.readable().await?;
311 
312             match server.try_read_buf(&mut read) {
313                 Ok(n) => i += n,
314                 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
315                 Err(e) => panic!("error = {:?}", e),
316             }
317         }
318 
319         assert_eq!(read, written);
320     }
321 
322     // Now, we listen for shutdown
323     drop(client);
324 
325     loop {
326         let ready = server.ready(Interest::READABLE).await?;
327 
328         if ready.is_read_closed() {
329             break;
330         } else {
331             tokio::task::yield_now().await;
332         }
333     }
334 
335     Ok(())
336 }
337