1 //! In-process memory IO types.
2 
3 use crate::io::{AsyncRead, AsyncWrite};
4 use crate::loom::sync::Mutex;
5 
6 use bytes::{Buf, BytesMut};
7 use std::{
8     pin::Pin,
9     sync::Arc,
10     task::{self, Poll, Waker},
11 };
12 
/// A bidirectional pipe to read and write bytes in memory.
///
/// A pair of `DuplexStream`s is created together, and they act as a "channel"
/// that can be used as in-memory IO types. Writing to one of the pair will
/// allow that data to be read from the other, and vice versa.
///
/// Dropping one side, or shutting down its write half, closes the direction it
/// writes into. Once any buffered bytes have been read, reads on the other
/// side return `Ok(0)` (EOF). After a shutdown, further writes from the
/// shut-down side fail with `BrokenPipe`.
///
/// # Example
///
/// ```
/// # async fn ex() -> std::io::Result<()> {
/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// let (mut client, mut server) = tokio::io::duplex(64);
///
/// client.write_all(b"ping").await?;
///
/// let mut buf = [0u8; 4];
/// server.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"ping");
///
/// server.write_all(b"pong").await?;
///
/// client.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"pong");
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct DuplexStream {
    /// Pipe this stream reads from; the peer stream writes into it.
    read: Arc<Mutex<Pipe>>,
    /// Pipe this stream writes into; the peer stream reads from it.
    write: Arc<Mutex<Pipe>>,
}
44 
/// A unidirectional IO over a piece of memory.
///
/// Data can be written to the pipe, and reading will return that data. Each
/// `DuplexStream` reads from one `Pipe` and writes into another; its peer
/// holds the same two pipes with the roles swapped.
#[derive(Debug)]
struct Pipe {
    /// The buffer storing the bytes written, also read from.
    ///
    /// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
    /// functionality already. Additionally, it can try to copy data in the
    /// same buffer if its read index has advanced far enough.
    buffer: BytesMut,
    /// Set once this pipe has been closed.
    ///
    /// After that, reads drain any bytes still buffered and then return
    /// `Ok(0)`, and writes fail with `BrokenPipe`.
    is_closed: bool,
    /// The maximum amount of bytes that can be buffered before a write
    /// returns `Poll::Pending`.
    max_buf_size: usize,
    /// If the `read` side has been polled and is pending, this is the waker
    /// for that parked task.
    read_waker: Option<Waker>,
    /// If the `write` side has filled the `max_buf_size` and returned
    /// `Poll::Pending`, this is the waker for that parked task.
    write_waker: Option<Waker>,
}
68 
69 // ===== impl DuplexStream =====
70 
71 /// Create a new pair of `DuplexStream`s that act like a pair of connected sockets.
72 ///
73 /// The `max_buf_size` argument is the maximum amount of bytes that can be
74 /// written to a side before the write returns `Poll::Pending`.
duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream)75 pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
76     let one = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
77     let two = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
78 
79     (
80         DuplexStream {
81             read: one.clone(),
82             write: two.clone(),
83         },
84         DuplexStream {
85             read: two,
86             write: one,
87         },
88     )
89 }
90 
91 impl AsyncRead for DuplexStream {
92     // Previous rustc required this `self` to be `mut`, even though newer
93     // versions recognize it isn't needed to call `lock()`. So for
94     // compatibility, we include the `mut` and `allow` the lint.
95     //
96     // See https://github.com/rust-lang/rust/issues/73592
97     #[allow(unused_mut)]
poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8], ) -> Poll<std::io::Result<usize>>98     fn poll_read(
99         mut self: Pin<&mut Self>,
100         cx: &mut task::Context<'_>,
101         buf: &mut [u8],
102     ) -> Poll<std::io::Result<usize>> {
103         Pin::new(&mut *self.read.lock().unwrap()).poll_read(cx, buf)
104     }
105 }
106 
107 impl AsyncWrite for DuplexStream {
108     #[allow(unused_mut)]
poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<std::io::Result<usize>>109     fn poll_write(
110         mut self: Pin<&mut Self>,
111         cx: &mut task::Context<'_>,
112         buf: &[u8],
113     ) -> Poll<std::io::Result<usize>> {
114         Pin::new(&mut *self.write.lock().unwrap()).poll_write(cx, buf)
115     }
116 
117     #[allow(unused_mut)]
poll_flush( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>118     fn poll_flush(
119         mut self: Pin<&mut Self>,
120         cx: &mut task::Context<'_>,
121     ) -> Poll<std::io::Result<()>> {
122         Pin::new(&mut *self.write.lock().unwrap()).poll_flush(cx)
123     }
124 
125     #[allow(unused_mut)]
poll_shutdown( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>126     fn poll_shutdown(
127         mut self: Pin<&mut Self>,
128         cx: &mut task::Context<'_>,
129     ) -> Poll<std::io::Result<()>> {
130         Pin::new(&mut *self.write.lock().unwrap()).poll_shutdown(cx)
131     }
132 }
133 
134 impl Drop for DuplexStream {
drop(&mut self)135     fn drop(&mut self) {
136         // notify the other side of the closure
137         self.write.lock().unwrap().close();
138     }
139 }
140 
141 // ===== impl Pipe =====
142 
143 impl Pipe {
new(max_buf_size: usize) -> Self144     fn new(max_buf_size: usize) -> Self {
145         Pipe {
146             buffer: BytesMut::new(),
147             is_closed: false,
148             max_buf_size,
149             read_waker: None,
150             write_waker: None,
151         }
152     }
153 
close(&mut self)154     fn close(&mut self) {
155         self.is_closed = true;
156         if let Some(waker) = self.read_waker.take() {
157             waker.wake();
158         }
159     }
160 }
161 
162 impl AsyncRead for Pipe {
poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8], ) -> Poll<std::io::Result<usize>>163     fn poll_read(
164         mut self: Pin<&mut Self>,
165         cx: &mut task::Context<'_>,
166         buf: &mut [u8],
167     ) -> Poll<std::io::Result<usize>> {
168         if self.buffer.has_remaining() {
169             let max = self.buffer.remaining().min(buf.len());
170             self.buffer.copy_to_slice(&mut buf[..max]);
171             if max > 0 {
172                 // The passed `buf` might have been empty, don't wake up if
173                 // no bytes have been moved.
174                 if let Some(waker) = self.write_waker.take() {
175                     waker.wake();
176                 }
177             }
178             Poll::Ready(Ok(max))
179         } else if self.is_closed {
180             Poll::Ready(Ok(0))
181         } else {
182             self.read_waker = Some(cx.waker().clone());
183             Poll::Pending
184         }
185     }
186 }
187 
188 impl AsyncWrite for Pipe {
poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<std::io::Result<usize>>189     fn poll_write(
190         mut self: Pin<&mut Self>,
191         cx: &mut task::Context<'_>,
192         buf: &[u8],
193     ) -> Poll<std::io::Result<usize>> {
194         if self.is_closed {
195             return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
196         }
197         let avail = self.max_buf_size - self.buffer.len();
198         if avail == 0 {
199             self.write_waker = Some(cx.waker().clone());
200             return Poll::Pending;
201         }
202 
203         let len = buf.len().min(avail);
204         self.buffer.extend_from_slice(&buf[..len]);
205         if let Some(waker) = self.read_waker.take() {
206             waker.wake();
207         }
208         Poll::Ready(Ok(len))
209     }
210 
poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>>211     fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
212         Poll::Ready(Ok(()))
213     }
214 
poll_shutdown( mut self: Pin<&mut Self>, _: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>215     fn poll_shutdown(
216         mut self: Pin<&mut Self>,
217         _: &mut task::Context<'_>,
218     ) -> Poll<std::io::Result<()>> {
219         self.close();
220         Poll::Ready(Ok(()))
221     }
222 }
223