1 //! `TcpStream` owned split support.
2 //!
//! A `TcpStream` can be split into an `OwnedReadHalf` and an `OwnedWriteHalf`
4 //! with the `TcpStream::into_split` method.  `OwnedReadHalf` implements
5 //! `AsyncRead` while `OwnedWriteHalf` implements `AsyncWrite`.
6 //!
7 //! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized
8 //! split has no associated overhead and enforces all invariants at the type
9 //! level.
10 
11 use crate::future::poll_fn;
12 use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
13 use crate::net::TcpStream;
14 
15 use std::error::Error;
16 use std::net::Shutdown;
17 use std::pin::Pin;
18 use std::sync::Arc;
19 use std::task::{Context, Poll};
20 use std::{fmt, io};
21 
/// Owned read half of a [`TcpStream`], created by [`into_split`].
///
/// Reading from an `OwnedReadHalf` is usually done using the convenience methods found
/// on the [`AsyncReadExt`] trait.
///
/// The half can be joined back with its matching write half through
/// `OwnedReadHalf::reunite`.
///
/// [`TcpStream`]: TcpStream
/// [`into_split`]: TcpStream::into_split()
/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
#[derive(Debug)]
pub struct OwnedReadHalf {
    // Reference-counted handle to the stream. The write half produced by the
    // same split holds the other reference to the same allocation.
    inner: Arc<TcpStream>,
}
34 
/// Owned write half of a [`TcpStream`], created by [`into_split`].
///
/// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will
/// shut down the TCP stream in the write direction.  Dropping the write half
/// will also shut down the write half of the TCP stream.
///
/// Writing to an `OwnedWriteHalf` is usually done using the convenience methods found
/// on the [`AsyncWriteExt`] trait.
///
/// [`TcpStream`]: TcpStream
/// [`into_split`]: TcpStream::into_split()
/// [`AsyncWrite`]: trait@crate::io::AsyncWrite
/// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
#[derive(Debug)]
pub struct OwnedWriteHalf {
    // Reference-counted handle to the stream. The read half produced by the
    // same split holds the other reference to the same allocation.
    inner: Arc<TcpStream>,
    // When `true`, the `Drop` impl shuts down the write direction. Starts out
    // `true` and is cleared by `forget` and by a successful `poll_shutdown`.
    shutdown_on_drop: bool,
}
54 
split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf)55 pub(crate) fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) {
56     let arc = Arc::new(stream);
57     let read = OwnedReadHalf {
58         inner: Arc::clone(&arc),
59     };
60     let write = OwnedWriteHalf {
61         inner: arc,
62         shutdown_on_drop: true,
63     };
64     (read, write)
65 }
66 
reunite( read: OwnedReadHalf, write: OwnedWriteHalf, ) -> Result<TcpStream, ReuniteError>67 pub(crate) fn reunite(
68     read: OwnedReadHalf,
69     write: OwnedWriteHalf,
70 ) -> Result<TcpStream, ReuniteError> {
71     if Arc::ptr_eq(&read.inner, &write.inner) {
72         write.forget();
73         // This unwrap cannot fail as the api does not allow creating more than two Arcs,
74         // and we just dropped the other half.
75         Ok(Arc::try_unwrap(read.inner).expect("TcpStream: try_unwrap failed in reunite"))
76     } else {
77         Err(ReuniteError(read, write))
78     }
79 }
80 
/// Error indicating that two halves were not from the same socket, and thus could
/// not be reunited.
///
/// The halves that were passed to `reunite` are handed back as the two fields:
/// the read half first, then the write half.
#[derive(Debug)]
pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf);
85 
86 impl fmt::Display for ReuniteError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result87     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88         write!(
89             f,
90             "tried to reunite halves that are not from the same socket"
91         )
92     }
93 }
94 
// Marks `ReuniteError` as a standard error type; every `Error` method keeps
// its default implementation.
impl Error for ReuniteError {}
96 
97 impl OwnedReadHalf {
98     /// Attempts to put the two halves of a `TcpStream` back together and
99     /// recover the original socket. Succeeds only if the two halves
100     /// originated from the same call to [`into_split`].
101     ///
102     /// [`into_split`]: TcpStream::into_split()
reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError>103     pub fn reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
104         reunite(self, other)
105     }
106 
107     /// Attempt to receive data on the socket, without removing that data from
108     /// the queue, registering the current task for wakeup if data is not yet
109     /// available.
110     ///
111     /// Note that on multiple calls to `poll_peek` or `poll_read`, only the
112     /// `Waker` from the `Context` passed to the most recent call is scheduled
113     /// to receive a wakeup.
114     ///
115     /// See the [`TcpStream::poll_peek`] level documentation for more details.
116     ///
117     /// # Examples
118     ///
119     /// ```no_run
120     /// use tokio::io::{self, ReadBuf};
121     /// use tokio::net::TcpStream;
122     ///
123     /// use futures::future::poll_fn;
124     ///
125     /// #[tokio::main]
126     /// async fn main() -> io::Result<()> {
127     ///     let stream = TcpStream::connect("127.0.0.1:8000").await?;
128     ///     let (mut read_half, _) = stream.into_split();
129     ///     let mut buf = [0; 10];
130     ///     let mut buf = ReadBuf::new(&mut buf);
131     ///
132     ///     poll_fn(|cx| {
133     ///         read_half.poll_peek(cx, &mut buf)
134     ///     }).await?;
135     ///
136     ///     Ok(())
137     /// }
138     /// ```
139     ///
140     /// [`TcpStream::poll_peek`]: TcpStream::poll_peek
poll_peek( &mut self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<usize>>141     pub fn poll_peek(
142         &mut self,
143         cx: &mut Context<'_>,
144         buf: &mut ReadBuf<'_>,
145     ) -> Poll<io::Result<usize>> {
146         self.inner.poll_peek(cx, buf)
147     }
148 
149     /// Receives data on the socket from the remote address to which it is
150     /// connected, without removing that data from the queue. On success,
151     /// returns the number of bytes peeked.
152     ///
153     /// See the [`TcpStream::peek`] level documentation for more details.
154     ///
155     /// [`TcpStream::peek`]: TcpStream::peek
156     ///
157     /// # Examples
158     ///
159     /// ```no_run
160     /// use tokio::net::TcpStream;
161     /// use tokio::io::AsyncReadExt;
162     /// use std::error::Error;
163     ///
164     /// #[tokio::main]
165     /// async fn main() -> Result<(), Box<dyn Error>> {
166     ///     // Connect to a peer
167     ///     let stream = TcpStream::connect("127.0.0.1:8080").await?;
168     ///     let (mut read_half, _) = stream.into_split();
169     ///
170     ///     let mut b1 = [0; 10];
171     ///     let mut b2 = [0; 10];
172     ///
173     ///     // Peek at the data
174     ///     let n = read_half.peek(&mut b1).await?;
175     ///
176     ///     // Read the data
177     ///     assert_eq!(n, read_half.read(&mut b2[..n]).await?);
178     ///     assert_eq!(&b1[..n], &b2[..n]);
179     ///
180     ///     Ok(())
181     /// }
182     /// ```
183     ///
184     /// The [`read`] method is defined on the [`AsyncReadExt`] trait.
185     ///
186     /// [`read`]: fn@crate::io::AsyncReadExt::read
187     /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
peek(&mut self, buf: &mut [u8]) -> io::Result<usize>188     pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
189         let mut buf = ReadBuf::new(buf);
190         poll_fn(|cx| self.poll_peek(cx, &mut buf)).await
191     }
192 }
193 
194 impl AsyncRead for OwnedReadHalf {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>195     fn poll_read(
196         self: Pin<&mut Self>,
197         cx: &mut Context<'_>,
198         buf: &mut ReadBuf<'_>,
199     ) -> Poll<io::Result<()>> {
200         self.inner.poll_read_priv(cx, buf)
201     }
202 }
203 
204 impl OwnedWriteHalf {
205     /// Attempts to put the two halves of a `TcpStream` back together and
206     /// recover the original socket. Succeeds only if the two halves
207     /// originated from the same call to [`into_split`].
208     ///
209     /// [`into_split`]: TcpStream::into_split()
reunite(self, other: OwnedReadHalf) -> Result<TcpStream, ReuniteError>210     pub fn reunite(self, other: OwnedReadHalf) -> Result<TcpStream, ReuniteError> {
211         reunite(other, self)
212     }
213 
214     /// Destroy the write half, but don't close the write half of the stream
215     /// until the read half is dropped. If the read half has already been
216     /// dropped, this closes the stream.
forget(mut self)217     pub fn forget(mut self) {
218         self.shutdown_on_drop = false;
219         drop(self);
220     }
221 }
222 
223 impl Drop for OwnedWriteHalf {
drop(&mut self)224     fn drop(&mut self) {
225         if self.shutdown_on_drop {
226             let _ = self.inner.shutdown_std(Shutdown::Write);
227         }
228     }
229 }
230 
231 impl AsyncWrite for OwnedWriteHalf {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>232     fn poll_write(
233         self: Pin<&mut Self>,
234         cx: &mut Context<'_>,
235         buf: &[u8],
236     ) -> Poll<io::Result<usize>> {
237         self.inner.poll_write_priv(cx, buf)
238     }
239 
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll<io::Result<usize>>240     fn poll_write_vectored(
241         self: Pin<&mut Self>,
242         cx: &mut Context<'_>,
243         bufs: &[io::IoSlice<'_>],
244     ) -> Poll<io::Result<usize>> {
245         self.inner.poll_write_vectored_priv(cx, bufs)
246     }
247 
is_write_vectored(&self) -> bool248     fn is_write_vectored(&self) -> bool {
249         self.inner.is_write_vectored()
250     }
251 
252     #[inline]
poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>>253     fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
254         // tcp flush is a no-op
255         Poll::Ready(Ok(()))
256     }
257 
258     // `poll_shutdown` on a write half shutdowns the stream in the "write" direction.
poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>>259     fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
260         let res = self.inner.shutdown_std(Shutdown::Write);
261         if res.is_ok() {
262             Pin::into_inner(self).shutdown_on_drop = false;
263         }
264         res.into()
265     }
266 }
267 
268 impl AsRef<TcpStream> for OwnedReadHalf {
as_ref(&self) -> &TcpStream269     fn as_ref(&self) -> &TcpStream {
270         &*self.inner
271     }
272 }
273 
274 impl AsRef<TcpStream> for OwnedWriteHalf {
as_ref(&self) -> &TcpStream275     fn as_ref(&self) -> &TcpStream {
276         &*self.inner
277     }
278 }
279