1 //! In-process memory IO types.
2 
3 use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
4 use crate::loom::sync::Mutex;
5 
6 use bytes::{Buf, BytesMut};
7 use std::{
8     pin::Pin,
9     sync::Arc,
10     task::{self, Poll, Waker},
11 };
12 
/// A bidirectional pipe to read and write bytes in memory.
///
/// A pair of `DuplexStream`s is created together, and the two act as a
/// "channel" that can be used as in-memory IO types. Writing to one of the
/// pair allows that data to be read from the other, and vice versa.
///
/// # Closing a `DuplexStream`
///
/// If one end of the `DuplexStream` channel is dropped, any pending reads on
/// the other side will continue to read data until the buffer is drained, then
/// they will signal EOF by returning 0 bytes. Any writes to the other side,
/// including pending ones (that are waiting for free space in the buffer),
/// will return `Err(BrokenPipe)` immediately.
///
/// # Example
///
/// ```
/// # async fn ex() -> std::io::Result<()> {
/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// let (mut client, mut server) = tokio::io::duplex(64);
///
/// client.write_all(b"ping").await?;
///
/// let mut buf = [0u8; 4];
/// server.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"ping");
///
/// server.write_all(b"pong").await?;
///
/// client.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"pong");
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
pub struct DuplexStream {
    /// Pipe this end reads from; the peer stream holds the same pipe as its
    /// `write` side.
    read: Arc<Mutex<Pipe>>,
    /// Pipe this end writes into; the peer stream holds the same pipe as its
    /// `read` side.
    write: Arc<Mutex<Pipe>>,
}
53 
/// A unidirectional IO over a piece of memory.
///
/// Data can be written to the pipe, and reading will return that data.
/// Each `Pipe` is shared (behind an `Arc<Mutex<_>>`) between the writing end
/// of one `DuplexStream` and the reading end of its peer.
#[derive(Debug)]
struct Pipe {
    /// The buffer storing the bytes written, also read from.
    ///
    /// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
    /// functionality already. Additionally, it can try to copy data in the
    /// same buffer if the read index has advanced far enough.
    buffer: BytesMut,
    /// Set once either end of the pipe has been closed: the writing end via
    /// `close_write` (on shutdown or drop) or the reading end via
    /// `close_read` (on drop). Once set, writes fail with `BrokenPipe`, and
    /// reads return EOF after any remaining buffered bytes are drained.
    is_closed: bool,
    /// The maximum number of bytes the buffer may hold. Once it is full,
    /// writes return `Poll::Pending` until the reader drains some bytes.
    max_buf_size: usize,
    /// If the `read` side has been polled and is pending, this is the waker
    /// for that parked task.
    read_waker: Option<Waker>,
    /// If the `write` side has filled the `max_buf_size` and returned
    /// `Poll::Pending`, this is the waker for that parked task.
    write_waker: Option<Waker>,
}
77 
78 // ===== impl DuplexStream =====
79 
80 /// Create a new pair of `DuplexStream`s that act like a pair of connected sockets.
81 ///
82 /// The `max_buf_size` argument is the maximum amount of bytes that can be
83 /// written to a side before the write returns `Poll::Pending`.
84 #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream)85 pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
86     let one = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
87     let two = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
88 
89     (
90         DuplexStream {
91             read: one.clone(),
92             write: two.clone(),
93         },
94         DuplexStream {
95             read: two,
96             write: one,
97         },
98     )
99 }
100 
101 impl AsyncRead for DuplexStream {
102     // Previous rustc required this `self` to be `mut`, even though newer
103     // versions recognize it isn't needed to call `lock()`. So for
104     // compatibility, we include the `mut` and `allow` the lint.
105     //
106     // See https://github.com/rust-lang/rust/issues/73592
107     #[allow(unused_mut)]
poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>108     fn poll_read(
109         mut self: Pin<&mut Self>,
110         cx: &mut task::Context<'_>,
111         buf: &mut ReadBuf<'_>,
112     ) -> Poll<std::io::Result<()>> {
113         Pin::new(&mut *self.read.lock()).poll_read(cx, buf)
114     }
115 }
116 
117 impl AsyncWrite for DuplexStream {
118     #[allow(unused_mut)]
poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<std::io::Result<usize>>119     fn poll_write(
120         mut self: Pin<&mut Self>,
121         cx: &mut task::Context<'_>,
122         buf: &[u8],
123     ) -> Poll<std::io::Result<usize>> {
124         Pin::new(&mut *self.write.lock()).poll_write(cx, buf)
125     }
126 
127     #[allow(unused_mut)]
poll_flush( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>128     fn poll_flush(
129         mut self: Pin<&mut Self>,
130         cx: &mut task::Context<'_>,
131     ) -> Poll<std::io::Result<()>> {
132         Pin::new(&mut *self.write.lock()).poll_flush(cx)
133     }
134 
135     #[allow(unused_mut)]
poll_shutdown( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>136     fn poll_shutdown(
137         mut self: Pin<&mut Self>,
138         cx: &mut task::Context<'_>,
139     ) -> Poll<std::io::Result<()>> {
140         Pin::new(&mut *self.write.lock()).poll_shutdown(cx)
141     }
142 }
143 
144 impl Drop for DuplexStream {
drop(&mut self)145     fn drop(&mut self) {
146         // notify the other side of the closure
147         self.write.lock().close_write();
148         self.read.lock().close_read();
149     }
150 }
151 
152 // ===== impl Pipe =====
153 
154 impl Pipe {
new(max_buf_size: usize) -> Self155     fn new(max_buf_size: usize) -> Self {
156         Pipe {
157             buffer: BytesMut::new(),
158             is_closed: false,
159             max_buf_size,
160             read_waker: None,
161             write_waker: None,
162         }
163     }
164 
close_write(&mut self)165     fn close_write(&mut self) {
166         self.is_closed = true;
167         // needs to notify any readers that no more data will come
168         if let Some(waker) = self.read_waker.take() {
169             waker.wake();
170         }
171     }
172 
close_read(&mut self)173     fn close_read(&mut self) {
174         self.is_closed = true;
175         // needs to notify any writers that they have to abort
176         if let Some(waker) = self.write_waker.take() {
177             waker.wake();
178         }
179     }
180 }
181 
182 impl AsyncRead for Pipe {
poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>183     fn poll_read(
184         mut self: Pin<&mut Self>,
185         cx: &mut task::Context<'_>,
186         buf: &mut ReadBuf<'_>,
187     ) -> Poll<std::io::Result<()>> {
188         if self.buffer.has_remaining() {
189             let max = self.buffer.remaining().min(buf.remaining());
190             buf.put_slice(&self.buffer[..max]);
191             self.buffer.advance(max);
192             if max > 0 {
193                 // The passed `buf` might have been empty, don't wake up if
194                 // no bytes have been moved.
195                 if let Some(waker) = self.write_waker.take() {
196                     waker.wake();
197                 }
198             }
199             Poll::Ready(Ok(()))
200         } else if self.is_closed {
201             Poll::Ready(Ok(()))
202         } else {
203             self.read_waker = Some(cx.waker().clone());
204             Poll::Pending
205         }
206     }
207 }
208 
209 impl AsyncWrite for Pipe {
poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<std::io::Result<usize>>210     fn poll_write(
211         mut self: Pin<&mut Self>,
212         cx: &mut task::Context<'_>,
213         buf: &[u8],
214     ) -> Poll<std::io::Result<usize>> {
215         if self.is_closed {
216             return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
217         }
218         let avail = self.max_buf_size - self.buffer.len();
219         if avail == 0 {
220             self.write_waker = Some(cx.waker().clone());
221             return Poll::Pending;
222         }
223 
224         let len = buf.len().min(avail);
225         self.buffer.extend_from_slice(&buf[..len]);
226         if let Some(waker) = self.read_waker.take() {
227             waker.wake();
228         }
229         Poll::Ready(Ok(len))
230     }
231 
poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>>232     fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
233         Poll::Ready(Ok(()))
234     }
235 
poll_shutdown( mut self: Pin<&mut Self>, _: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>236     fn poll_shutdown(
237         mut self: Pin<&mut Self>,
238         _: &mut task::Context<'_>,
239     ) -> Poll<std::io::Result<()>> {
240         self.close_write();
241         Poll::Ready(Ok(()))
242     }
243 }
244