1 //! In-process memory IO types.
2
3 use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
4 use crate::loom::sync::Mutex;
5
6 use bytes::{Buf, BytesMut};
7 use std::{
8 pin::Pin,
9 sync::Arc,
10 task::{self, Poll, Waker},
11 };
12
/// A bidirectional pipe to read and write bytes in memory.
///
/// `DuplexStream`s are created in pairs, and the two halves act as a
/// "channel" that can be used as in-memory IO types. Writing to one half of
/// the pair will allow that data to be read from the other, and vice versa.
///
/// # Closing a `DuplexStream`
///
/// If one end of the `DuplexStream` channel is dropped, any pending reads on
/// the other side will continue to read data until the buffer is drained, then
/// they will signal EOF by returning 0 bytes. Any writes to the other side,
/// including pending ones (that are waiting for free space in the buffer) will
/// return `Err(BrokenPipe)` immediately.
///
/// # Example
///
/// ```
/// # async fn ex() -> std::io::Result<()> {
/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// let (mut client, mut server) = tokio::io::duplex(64);
///
/// client.write_all(b"ping").await?;
///
/// let mut buf = [0u8; 4];
/// server.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"ping");
///
/// server.write_all(b"pong").await?;
///
/// client.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"pong");
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
pub struct DuplexStream {
    /// The pipe this end reads from; the paired stream holds the same pipe
    /// as its `write` half.
    read: Arc<Mutex<Pipe>>,
    /// The pipe this end writes into; the paired stream holds the same pipe
    /// as its `read` half.
    write: Arc<Mutex<Pipe>>,
}
53
/// A unidirectional IO over a piece of memory.
///
/// Data can be written to the pipe, and reading will return that data.
#[derive(Debug)]
struct Pipe {
    /// The buffer storing the bytes written, also read from.
    ///
    /// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
    /// functionality already. Additionally, it can try to copy data in the
    /// same buffer if the read index has advanced far enough.
    buffer: BytesMut,
    /// Set once the pipe has been closed from either direction (both
    /// `close_write` and `close_read` set it).
    ///
    /// When set, reads return without filling the caller's buffer as soon as
    /// `buffer` is empty, and writes fail with `BrokenPipe`.
    is_closed: bool,
    /// The maximum amount of bytes that can be written before returning
    /// `Poll::Pending`.
    max_buf_size: usize,
    /// If the `read` side has been polled and is pending, this is the waker
    /// for that parked task.
    read_waker: Option<Waker>,
    /// If the `write` side has filled the `max_buf_size` and returned
    /// `Poll::Pending`, this is the waker for that parked task.
    write_waker: Option<Waker>,
}
77
78 // ===== impl DuplexStream =====
79
80 /// Create a new pair of `DuplexStream`s that act like a pair of connected sockets.
81 ///
82 /// The `max_buf_size` argument is the maximum amount of bytes that can be
83 /// written to a side before the write returns `Poll::Pending`.
84 #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream)85 pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
86 let one = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
87 let two = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
88
89 (
90 DuplexStream {
91 read: one.clone(),
92 write: two.clone(),
93 },
94 DuplexStream {
95 read: two,
96 write: one,
97 },
98 )
99 }
100
101 impl AsyncRead for DuplexStream {
102 // Previous rustc required this `self` to be `mut`, even though newer
103 // versions recognize it isn't needed to call `lock()`. So for
104 // compatibility, we include the `mut` and `allow` the lint.
105 //
106 // See https://github.com/rust-lang/rust/issues/73592
107 #[allow(unused_mut)]
poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>108 fn poll_read(
109 mut self: Pin<&mut Self>,
110 cx: &mut task::Context<'_>,
111 buf: &mut ReadBuf<'_>,
112 ) -> Poll<std::io::Result<()>> {
113 Pin::new(&mut *self.read.lock()).poll_read(cx, buf)
114 }
115 }
116
117 impl AsyncWrite for DuplexStream {
118 #[allow(unused_mut)]
poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<std::io::Result<usize>>119 fn poll_write(
120 mut self: Pin<&mut Self>,
121 cx: &mut task::Context<'_>,
122 buf: &[u8],
123 ) -> Poll<std::io::Result<usize>> {
124 Pin::new(&mut *self.write.lock()).poll_write(cx, buf)
125 }
126
127 #[allow(unused_mut)]
poll_flush( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>128 fn poll_flush(
129 mut self: Pin<&mut Self>,
130 cx: &mut task::Context<'_>,
131 ) -> Poll<std::io::Result<()>> {
132 Pin::new(&mut *self.write.lock()).poll_flush(cx)
133 }
134
135 #[allow(unused_mut)]
poll_shutdown( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>136 fn poll_shutdown(
137 mut self: Pin<&mut Self>,
138 cx: &mut task::Context<'_>,
139 ) -> Poll<std::io::Result<()>> {
140 Pin::new(&mut *self.write.lock()).poll_shutdown(cx)
141 }
142 }
143
144 impl Drop for DuplexStream {
drop(&mut self)145 fn drop(&mut self) {
146 // notify the other side of the closure
147 self.write.lock().close_write();
148 self.read.lock().close_read();
149 }
150 }
151
152 // ===== impl Pipe =====
153
154 impl Pipe {
new(max_buf_size: usize) -> Self155 fn new(max_buf_size: usize) -> Self {
156 Pipe {
157 buffer: BytesMut::new(),
158 is_closed: false,
159 max_buf_size,
160 read_waker: None,
161 write_waker: None,
162 }
163 }
164
close_write(&mut self)165 fn close_write(&mut self) {
166 self.is_closed = true;
167 // needs to notify any readers that no more data will come
168 if let Some(waker) = self.read_waker.take() {
169 waker.wake();
170 }
171 }
172
close_read(&mut self)173 fn close_read(&mut self) {
174 self.is_closed = true;
175 // needs to notify any writers that they have to abort
176 if let Some(waker) = self.write_waker.take() {
177 waker.wake();
178 }
179 }
180 }
181
182 impl AsyncRead for Pipe {
poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<std::io::Result<()>>183 fn poll_read(
184 mut self: Pin<&mut Self>,
185 cx: &mut task::Context<'_>,
186 buf: &mut ReadBuf<'_>,
187 ) -> Poll<std::io::Result<()>> {
188 if self.buffer.has_remaining() {
189 let max = self.buffer.remaining().min(buf.remaining());
190 buf.put_slice(&self.buffer[..max]);
191 self.buffer.advance(max);
192 if max > 0 {
193 // The passed `buf` might have been empty, don't wake up if
194 // no bytes have been moved.
195 if let Some(waker) = self.write_waker.take() {
196 waker.wake();
197 }
198 }
199 Poll::Ready(Ok(()))
200 } else if self.is_closed {
201 Poll::Ready(Ok(()))
202 } else {
203 self.read_waker = Some(cx.waker().clone());
204 Poll::Pending
205 }
206 }
207 }
208
209 impl AsyncWrite for Pipe {
poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<std::io::Result<usize>>210 fn poll_write(
211 mut self: Pin<&mut Self>,
212 cx: &mut task::Context<'_>,
213 buf: &[u8],
214 ) -> Poll<std::io::Result<usize>> {
215 if self.is_closed {
216 return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
217 }
218 let avail = self.max_buf_size - self.buffer.len();
219 if avail == 0 {
220 self.write_waker = Some(cx.waker().clone());
221 return Poll::Pending;
222 }
223
224 let len = buf.len().min(avail);
225 self.buffer.extend_from_slice(&buf[..len]);
226 if let Some(waker) = self.read_waker.take() {
227 waker.wake();
228 }
229 Poll::Ready(Ok(len))
230 }
231
poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>>232 fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
233 Poll::Ready(Ok(()))
234 }
235
poll_shutdown( mut self: Pin<&mut Self>, _: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>236 fn poll_shutdown(
237 mut self: Pin<&mut Self>,
238 _: &mut task::Context<'_>,
239 ) -> Poll<std::io::Result<()>> {
240 self.close_write();
241 Poll::Ready(Ok(()))
242 }
243 }
244