1 #![cfg(feature = "full")]
2 #![warn(rust_2018_idioms)]
3 #![cfg(unix)]
4
5 use std::io;
6 use std::task::Poll;
7
8 use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest};
9 use tokio::net::{UnixListener, UnixStream};
10 use tokio_test::{assert_ok, assert_pending, assert_ready_ok, task};
11
12 use futures::future::{poll_fn, try_join};
13
14 #[tokio::test]
accept_read_write() -> std::io::Result<()>15 async fn accept_read_write() -> std::io::Result<()> {
16 let dir = tempfile::Builder::new()
17 .prefix("tokio-uds-tests")
18 .tempdir()
19 .unwrap();
20 let sock_path = dir.path().join("connect.sock");
21
22 let listener = UnixListener::bind(&sock_path)?;
23
24 let accept = listener.accept();
25 let connect = UnixStream::connect(&sock_path);
26 let ((mut server, _), mut client) = try_join(accept, connect).await?;
27
28 // Write to the client. TODO: Switch to write_all.
29 let write_len = client.write(b"hello").await?;
30 assert_eq!(write_len, 5);
31 drop(client);
32 // Read from the server. TODO: Switch to read_to_end.
33 let mut buf = [0u8; 5];
34 server.read_exact(&mut buf).await?;
35 assert_eq!(&buf, b"hello");
36 let len = server.read(&mut buf).await?;
37 assert_eq!(len, 0);
38 Ok(())
39 }
40
41 #[tokio::test]
shutdown() -> std::io::Result<()>42 async fn shutdown() -> std::io::Result<()> {
43 let dir = tempfile::Builder::new()
44 .prefix("tokio-uds-tests")
45 .tempdir()
46 .unwrap();
47 let sock_path = dir.path().join("connect.sock");
48
49 let listener = UnixListener::bind(&sock_path)?;
50
51 let accept = listener.accept();
52 let connect = UnixStream::connect(&sock_path);
53 let ((mut server, _), mut client) = try_join(accept, connect).await?;
54
55 // Shut down the client
56 AsyncWriteExt::shutdown(&mut client).await?;
57 // Read from the server should return 0 to indicate the channel has been closed.
58 let mut buf = [0u8; 1];
59 let n = server.read(&mut buf).await?;
60 assert_eq!(n, 0);
61 Ok(())
62 }
63
64 #[tokio::test]
try_read_write() -> std::io::Result<()>65 async fn try_read_write() -> std::io::Result<()> {
66 let msg = b"hello world";
67
68 let dir = tempfile::tempdir()?;
69 let bind_path = dir.path().join("bind.sock");
70
71 // Create listener
72 let listener = UnixListener::bind(&bind_path)?;
73
74 // Create socket pair
75 let client = UnixStream::connect(&bind_path).await?;
76
77 let (server, _) = listener.accept().await?;
78 let mut written = msg.to_vec();
79
80 // Track the server receiving data
81 let mut readable = task::spawn(server.readable());
82 assert_pending!(readable.poll());
83
84 // Write data.
85 client.writable().await?;
86 assert_eq!(msg.len(), client.try_write(msg)?);
87
88 // The task should be notified
89 while !readable.is_woken() {
90 tokio::task::yield_now().await;
91 }
92
93 // Fill the write buffer
94 loop {
95 // Still ready
96 let mut writable = task::spawn(client.writable());
97 assert_ready_ok!(writable.poll());
98
99 match client.try_write(msg) {
100 Ok(n) => written.extend(&msg[..n]),
101 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
102 break;
103 }
104 Err(e) => panic!("error = {:?}", e),
105 }
106 }
107
108 {
109 // Write buffer full
110 let mut writable = task::spawn(client.writable());
111 assert_pending!(writable.poll());
112
113 // Drain the socket from the server end
114 let mut read = vec![0; written.len()];
115 let mut i = 0;
116
117 while i < read.len() {
118 server.readable().await?;
119
120 match server.try_read(&mut read[i..]) {
121 Ok(n) => i += n,
122 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
123 Err(e) => panic!("error = {:?}", e),
124 }
125 }
126
127 assert_eq!(read, written);
128 }
129
130 // Now, we listen for shutdown
131 drop(client);
132
133 loop {
134 let ready = server.ready(Interest::READABLE).await?;
135
136 if ready.is_read_closed() {
137 break;
138 } else {
139 tokio::task::yield_now().await;
140 }
141 }
142
143 Ok(())
144 }
145
create_pair() -> (UnixStream, UnixStream)146 async fn create_pair() -> (UnixStream, UnixStream) {
147 let dir = assert_ok!(tempfile::tempdir());
148 let bind_path = dir.path().join("bind.sock");
149
150 let listener = assert_ok!(UnixListener::bind(&bind_path));
151
152 let accept = listener.accept();
153 let connect = UnixStream::connect(&bind_path);
154 let ((server, _), client) = assert_ok!(try_join(accept, connect).await);
155
156 (client, server)
157 }
158
// Waits for `$stream.poll_read_ready` to complete and asserts it returned `Ok`.
macro_rules! assert_readable_by_polling {
    ($stream:expr) => {{
        let readiness = poll_fn(|cx| $stream.poll_read_ready(cx)).await;
        assert_ok!(readiness);
    }};
}
164
// Polls `$stream.poll_read_ready` exactly once and asserts it is
// `Poll::Pending`. The wrapping `poll_fn` completes on its first poll, so
// the caller never waits for readiness.
macro_rules! assert_not_readable_by_polling {
    ($stream:expr) => {{
        let single_poll = poll_fn(|cx| {
            assert_pending!($stream.poll_read_ready(cx));
            Poll::Ready(())
        });
        single_poll.await;
    }};
}
174
// Waits for `$stream.poll_write_ready` to complete and asserts it returned `Ok`.
macro_rules! assert_writable_by_polling {
    ($stream:expr) => {{
        let readiness = poll_fn(|cx| $stream.poll_write_ready(cx)).await;
        assert_ok!(readiness);
    }};
}
180
// Polls `$stream.poll_write_ready` exactly once and asserts it is
// `Poll::Pending`. The wrapping `poll_fn` completes on its first poll, so
// the caller never waits for readiness.
macro_rules! assert_not_writable_by_polling {
    ($stream:expr) => {{
        let single_poll = poll_fn(|cx| {
            assert_pending!($stream.poll_write_ready(cx));
            Poll::Ready(())
        });
        single_poll.await;
    }};
}
190
191 #[tokio::test]
poll_read_ready()192 async fn poll_read_ready() {
193 let (mut client, mut server) = create_pair().await;
194
195 // Initial state - not readable.
196 assert_not_readable_by_polling!(server);
197
198 // There is data in the buffer - readable.
199 assert_ok!(client.write_all(b"ping").await);
200 assert_readable_by_polling!(server);
201
202 // Readable until calls to `poll_read` return `Poll::Pending`.
203 let mut buf = [0u8; 4];
204 assert_ok!(server.read_exact(&mut buf).await);
205 assert_readable_by_polling!(server);
206 read_until_pending(&mut server);
207 assert_not_readable_by_polling!(server);
208
209 // Detect the client disconnect.
210 drop(client);
211 assert_readable_by_polling!(server);
212 }
213
214 #[tokio::test]
poll_write_ready()215 async fn poll_write_ready() {
216 let (mut client, server) = create_pair().await;
217
218 // Initial state - writable.
219 assert_writable_by_polling!(client);
220
221 // No space to write - not writable.
222 write_until_pending(&mut client);
223 assert_not_writable_by_polling!(client);
224
225 // Detect the server disconnect.
226 drop(server);
227 assert_writable_by_polling!(client);
228 }
229
read_until_pending(stream: &mut UnixStream)230 fn read_until_pending(stream: &mut UnixStream) {
231 let mut buf = vec![0u8; 1024 * 1024];
232 loop {
233 match stream.try_read(&mut buf) {
234 Ok(_) => (),
235 Err(err) => {
236 assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
237 break;
238 }
239 }
240 }
241 }
242
write_until_pending(stream: &mut UnixStream)243 fn write_until_pending(stream: &mut UnixStream) {
244 let buf = vec![0u8; 1024 * 1024];
245 loop {
246 match stream.try_write(&buf) {
247 Ok(_) => (),
248 Err(err) => {
249 assert_eq!(err.kind(), io::ErrorKind::WouldBlock);
250 break;
251 }
252 }
253 }
254 }
255
256 #[tokio::test]
try_read_buf() -> std::io::Result<()>257 async fn try_read_buf() -> std::io::Result<()> {
258 let msg = b"hello world";
259
260 let dir = tempfile::tempdir()?;
261 let bind_path = dir.path().join("bind.sock");
262
263 // Create listener
264 let listener = UnixListener::bind(&bind_path)?;
265
266 // Create socket pair
267 let client = UnixStream::connect(&bind_path).await?;
268
269 let (server, _) = listener.accept().await?;
270 let mut written = msg.to_vec();
271
272 // Track the server receiving data
273 let mut readable = task::spawn(server.readable());
274 assert_pending!(readable.poll());
275
276 // Write data.
277 client.writable().await?;
278 assert_eq!(msg.len(), client.try_write(msg)?);
279
280 // The task should be notified
281 while !readable.is_woken() {
282 tokio::task::yield_now().await;
283 }
284
285 // Fill the write buffer
286 loop {
287 // Still ready
288 let mut writable = task::spawn(client.writable());
289 assert_ready_ok!(writable.poll());
290
291 match client.try_write(msg) {
292 Ok(n) => written.extend(&msg[..n]),
293 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
294 break;
295 }
296 Err(e) => panic!("error = {:?}", e),
297 }
298 }
299
300 {
301 // Write buffer full
302 let mut writable = task::spawn(client.writable());
303 assert_pending!(writable.poll());
304
305 // Drain the socket from the server end
306 let mut read = Vec::with_capacity(written.len());
307 let mut i = 0;
308
309 while i < read.capacity() {
310 server.readable().await?;
311
312 match server.try_read_buf(&mut read) {
313 Ok(n) => i += n,
314 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue,
315 Err(e) => panic!("error = {:?}", e),
316 }
317 }
318
319 assert_eq!(read, written);
320 }
321
322 // Now, we listen for shutdown
323 drop(client);
324
325 loop {
326 let ready = server.ready(Interest::READABLE).await?;
327
328 if ready.is_read_closed() {
329 break;
330 } else {
331 tokio::task::yield_now().await;
332 }
333 }
334
335 Ok(())
336 }
337