1 //! In-process memory IO types.
2
3 use crate::io::{AsyncRead, AsyncWrite};
4 use crate::loom::sync::Mutex;
5
6 use bytes::{Buf, BytesMut};
7 use std::{
8 pin::Pin,
9 sync::Arc,
10 task::{self, Poll, Waker},
11 };
12
/// A bidirectional pipe to read and write bytes in memory.
///
/// A pair of `DuplexStream`s are created together, and they act as a "channel"
/// that can be used as in-memory IO types. Writing to one of the pairs will
/// allow that data to be read from the other, and vice versa.
///
/// # Example
///
/// ```
/// # async fn ex() -> std::io::Result<()> {
/// # use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// let (mut client, mut server) = tokio::io::duplex(64);
///
/// client.write_all(b"ping").await?;
///
/// let mut buf = [0u8; 4];
/// server.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"ping");
///
/// server.write_all(b"pong").await?;
///
/// client.read_exact(&mut buf).await?;
/// assert_eq!(&buf, b"pong");
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct DuplexStream {
    /// The pipe this stream reads from; the peer stream holds the same pipe as
    /// its `write` side.
    read: Arc<Mutex<Pipe>>,
    /// The pipe this stream writes into; the peer stream holds the same pipe as
    /// its `read` side.
    write: Arc<Mutex<Pipe>>,
}
44
/// A unidirectional IO over a piece of memory.
///
/// Data can be written to the pipe, and reading will return that data. Each
/// pipe is shared (behind `Arc<Mutex<_>>`) by the `DuplexStream` that writes
/// into it and the peer `DuplexStream` that reads from it.
#[derive(Debug)]
struct Pipe {
    /// The buffer storing the bytes written, also read from.
    ///
    /// Using a `BytesMut` because it has efficient `Buf` and `BufMut`
    /// functionality already. Additionally, it can try to copy data in the
    /// same buffer if the read index has advanced far enough.
    buffer: BytesMut,
    /// Set once the pipe has been closed. Afterwards writes fail with
    /// `BrokenPipe`, and reads return `Ok(0)` (EOF) once `buffer` is drained.
    is_closed: bool,
    /// The maximum number of bytes `buffer` may hold; a write that finds no
    /// room left returns `Poll::Pending`.
    max_buf_size: usize,
    /// If the `read` side has been polled and is pending, this is the waker
    /// for that parked task.
    read_waker: Option<Waker>,
    /// If the `write` side has filled the `max_buf_size` and returned
    /// `Poll::Pending`, this is the waker for that parked task.
    write_waker: Option<Waker>,
}
68
69 // ===== impl DuplexStream =====
70
71 /// Create a new pair of `DuplexStream`s that act like a pair of connected sockets.
72 ///
73 /// The `max_buf_size` argument is the maximum amount of bytes that can be
74 /// written to a side before the write returns `Poll::Pending`.
duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream)75 pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) {
76 let one = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
77 let two = Arc::new(Mutex::new(Pipe::new(max_buf_size)));
78
79 (
80 DuplexStream {
81 read: one.clone(),
82 write: two.clone(),
83 },
84 DuplexStream {
85 read: two,
86 write: one,
87 },
88 )
89 }
90
91 impl AsyncRead for DuplexStream {
92 // Previous rustc required this `self` to be `mut`, even though newer
93 // versions recognize it isn't needed to call `lock()`. So for
94 // compatibility, we include the `mut` and `allow` the lint.
95 //
96 // See https://github.com/rust-lang/rust/issues/73592
97 #[allow(unused_mut)]
poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8], ) -> Poll<std::io::Result<usize>>98 fn poll_read(
99 mut self: Pin<&mut Self>,
100 cx: &mut task::Context<'_>,
101 buf: &mut [u8],
102 ) -> Poll<std::io::Result<usize>> {
103 Pin::new(&mut *self.read.lock().unwrap()).poll_read(cx, buf)
104 }
105 }
106
107 impl AsyncWrite for DuplexStream {
108 #[allow(unused_mut)]
poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<std::io::Result<usize>>109 fn poll_write(
110 mut self: Pin<&mut Self>,
111 cx: &mut task::Context<'_>,
112 buf: &[u8],
113 ) -> Poll<std::io::Result<usize>> {
114 Pin::new(&mut *self.write.lock().unwrap()).poll_write(cx, buf)
115 }
116
117 #[allow(unused_mut)]
poll_flush( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>118 fn poll_flush(
119 mut self: Pin<&mut Self>,
120 cx: &mut task::Context<'_>,
121 ) -> Poll<std::io::Result<()>> {
122 Pin::new(&mut *self.write.lock().unwrap()).poll_flush(cx)
123 }
124
125 #[allow(unused_mut)]
poll_shutdown( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>126 fn poll_shutdown(
127 mut self: Pin<&mut Self>,
128 cx: &mut task::Context<'_>,
129 ) -> Poll<std::io::Result<()>> {
130 Pin::new(&mut *self.write.lock().unwrap()).poll_shutdown(cx)
131 }
132 }
133
134 impl Drop for DuplexStream {
drop(&mut self)135 fn drop(&mut self) {
136 // notify the other side of the closure
137 self.write.lock().unwrap().close();
138 }
139 }
140
141 // ===== impl Pipe =====
142
143 impl Pipe {
new(max_buf_size: usize) -> Self144 fn new(max_buf_size: usize) -> Self {
145 Pipe {
146 buffer: BytesMut::new(),
147 is_closed: false,
148 max_buf_size,
149 read_waker: None,
150 write_waker: None,
151 }
152 }
153
close(&mut self)154 fn close(&mut self) {
155 self.is_closed = true;
156 if let Some(waker) = self.read_waker.take() {
157 waker.wake();
158 }
159 }
160 }
161
162 impl AsyncRead for Pipe {
poll_read( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut [u8], ) -> Poll<std::io::Result<usize>>163 fn poll_read(
164 mut self: Pin<&mut Self>,
165 cx: &mut task::Context<'_>,
166 buf: &mut [u8],
167 ) -> Poll<std::io::Result<usize>> {
168 if self.buffer.has_remaining() {
169 let max = self.buffer.remaining().min(buf.len());
170 self.buffer.copy_to_slice(&mut buf[..max]);
171 if max > 0 {
172 // The passed `buf` might have been empty, don't wake up if
173 // no bytes have been moved.
174 if let Some(waker) = self.write_waker.take() {
175 waker.wake();
176 }
177 }
178 Poll::Ready(Ok(max))
179 } else if self.is_closed {
180 Poll::Ready(Ok(0))
181 } else {
182 self.read_waker = Some(cx.waker().clone());
183 Poll::Pending
184 }
185 }
186 }
187
188 impl AsyncWrite for Pipe {
poll_write( mut self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll<std::io::Result<usize>>189 fn poll_write(
190 mut self: Pin<&mut Self>,
191 cx: &mut task::Context<'_>,
192 buf: &[u8],
193 ) -> Poll<std::io::Result<usize>> {
194 if self.is_closed {
195 return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
196 }
197 let avail = self.max_buf_size - self.buffer.len();
198 if avail == 0 {
199 self.write_waker = Some(cx.waker().clone());
200 return Poll::Pending;
201 }
202
203 let len = buf.len().min(avail);
204 self.buffer.extend_from_slice(&buf[..len]);
205 if let Some(waker) = self.read_waker.take() {
206 waker.wake();
207 }
208 Poll::Ready(Ok(len))
209 }
210
poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>>211 fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<std::io::Result<()>> {
212 Poll::Ready(Ok(()))
213 }
214
poll_shutdown( mut self: Pin<&mut Self>, _: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>>215 fn poll_shutdown(
216 mut self: Pin<&mut Self>,
217 _: &mut task::Context<'_>,
218 ) -> Poll<std::io::Result<()>> {
219 self.close();
220 Poll::Ready(Ok(()))
221 }
222 }
223