1 use bytes::Buf; 2 use futures_core::stream::Stream; 3 use pin_project_lite::pin_project; 4 use std::io; 5 use std::pin::Pin; 6 use std::task::{Context, Poll}; 7 use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf}; 8 9 pin_project! { 10 /// Convert a [`Stream`] of byte chunks into an [`AsyncRead`]. 11 /// 12 /// This type performs the inverse operation of [`ReaderStream`]. 13 /// 14 /// # Example 15 /// 16 /// ``` 17 /// use bytes::Bytes; 18 /// use tokio::io::{AsyncReadExt, Result}; 19 /// use tokio_util::io::StreamReader; 20 /// # #[tokio::main] 21 /// # async fn main() -> std::io::Result<()> { 22 /// 23 /// // Create a stream from an iterator. 24 /// let stream = tokio_stream::iter(vec![ 25 /// Result::Ok(Bytes::from_static(&[0, 1, 2, 3])), 26 /// Result::Ok(Bytes::from_static(&[4, 5, 6, 7])), 27 /// Result::Ok(Bytes::from_static(&[8, 9, 10, 11])), 28 /// ]); 29 /// 30 /// // Convert it to an AsyncRead. 31 /// let mut read = StreamReader::new(stream); 32 /// 33 /// // Read five bytes from the stream. 34 /// let mut buf = [0; 5]; 35 /// read.read_exact(&mut buf).await?; 36 /// assert_eq!(buf, [0, 1, 2, 3, 4]); 37 /// 38 /// // Read the rest of the current chunk. 39 /// assert_eq!(read.read(&mut buf).await?, 3); 40 /// assert_eq!(&buf[..3], [5, 6, 7]); 41 /// 42 /// // Read the next chunk. 43 /// assert_eq!(read.read(&mut buf).await?, 4); 44 /// assert_eq!(&buf[..4], [8, 9, 10, 11]); 45 /// 46 /// // We have now reached the end. 47 /// assert_eq!(read.read(&mut buf).await?, 0); 48 /// 49 /// # Ok(()) 50 /// # } 51 /// ``` 52 /// 53 /// [`AsyncRead`]: tokio::io::AsyncRead 54 /// [`Stream`]: futures_core::Stream 55 /// [`ReaderStream`]: crate::io::ReaderStream 56 #[derive(Debug)] 57 pub struct StreamReader<S, B> { 58 #[pin] 59 inner: S, 60 chunk: Option<B>, 61 } 62 } 63 64 impl<S, B, E> StreamReader<S, B> 65 where 66 S: Stream<Item = Result<B, E>>, 67 B: Buf, 68 E: Into<std::io::Error>, 69 { 70 /// Convert a stream of byte chunks into an [`AsyncRead`](tokio::io::AsyncRead). 71 /// 72 /// The item should be a [`Result`] with the ok variant being something that 73 /// implements the [`Buf`] trait (e.g. `Vec<u8>` or `Bytes`). The error 74 /// should be convertible into an [io error]. 75 /// 76 /// [`Result`]: std::result::Result 77 /// [`Buf`]: bytes::Buf 78 /// [io error]: std::io::Error new(stream: S) -> Self79 pub fn new(stream: S) -> Self { 80 Self { 81 inner: stream, 82 chunk: None, 83 } 84 } 85 86 /// Do we have a chunk and is it non-empty? has_chunk(self: Pin<&mut Self>) -> bool87 fn has_chunk(self: Pin<&mut Self>) -> bool { 88 if let Some(chunk) = self.project().chunk { 89 chunk.remaining() > 0 90 } else { 91 false 92 } 93 } 94 } 95 96 impl<S, B> StreamReader<S, B> { 97 /// Gets a reference to the underlying stream. 98 /// 99 /// It is inadvisable to directly read from the underlying stream. get_ref(&self) -> &S100 pub fn get_ref(&self) -> &S { 101 &self.inner 102 } 103 104 /// Gets a mutable reference to the underlying stream. 105 /// 106 /// It is inadvisable to directly read from the underlying stream. get_mut(&mut self) -> &mut S107 pub fn get_mut(&mut self) -> &mut S { 108 &mut self.inner 109 } 110 111 /// Gets a pinned mutable reference to the underlying stream. 112 /// 113 /// It is inadvisable to directly read from the underlying stream. get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S>114 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> { 115 self.project().inner 116 } 117 118 /// Consumes this `BufWriter`, returning the underlying stream. 119 /// 120 /// Note that any leftover data in the internal buffer is lost. into_inner(self) -> S121 pub fn into_inner(self) -> S { 122 self.inner 123 } 124 } 125 126 impl<S, B, E> AsyncRead for StreamReader<S, B> 127 where 128 S: Stream<Item = Result<B, E>>, 129 B: Buf, 130 E: Into<std::io::Error>, 131 { poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>132 fn poll_read( 133 mut self: Pin<&mut Self>, 134 cx: &mut Context<'_>, 135 buf: &mut ReadBuf<'_>, 136 ) -> Poll<io::Result<()>> { 137 if buf.remaining() == 0 { 138 return Poll::Ready(Ok(())); 139 } 140 141 let inner_buf = match self.as_mut().poll_fill_buf(cx) { 142 Poll::Ready(Ok(buf)) => buf, 143 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), 144 Poll::Pending => return Poll::Pending, 145 }; 146 let len = std::cmp::min(inner_buf.len(), buf.remaining()); 147 buf.put_slice(&inner_buf[..len]); 148 149 self.consume(len); 150 Poll::Ready(Ok(())) 151 } 152 } 153 154 impl<S, B, E> AsyncBufRead for StreamReader<S, B> 155 where 156 S: Stream<Item = Result<B, E>>, 157 B: Buf, 158 E: Into<std::io::Error>, 159 { poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>160 fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { 161 loop { 162 if self.as_mut().has_chunk() { 163 // This unwrap is very sad, but it can't be avoided. 164 let buf = self.project().chunk.as_ref().unwrap().chunk(); 165 return Poll::Ready(Ok(buf)); 166 } else { 167 match self.as_mut().project().inner.poll_next(cx) { 168 Poll::Ready(Some(Ok(chunk))) => { 169 // Go around the loop in case the chunk is empty. 170 *self.as_mut().project().chunk = Some(chunk); 171 } 172 Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())), 173 Poll::Ready(None) => return Poll::Ready(Ok(&[])), 174 Poll::Pending => return Poll::Pending, 175 } 176 } 177 } 178 } consume(self: Pin<&mut Self>, amt: usize)179 fn consume(self: Pin<&mut Self>, amt: usize) { 180 if amt > 0 { 181 self.project() 182 .chunk 183 .as_mut() 184 .expect("No chunk present") 185 .advance(amt); 186 } 187 } 188 } 189