1 use bytes::Buf;
2 use futures_core::stream::Stream;
3 use pin_project_lite::pin_project;
4 use std::io;
5 use std::pin::Pin;
6 use std::task::{Context, Poll};
7 use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
8 
pin_project! {
    /// Convert a [`Stream`] of byte chunks into an [`AsyncRead`].
    ///
    /// This type performs the inverse operation of [`ReaderStream`].
    ///
    /// Each `Ok` item yielded by the stream is buffered and handed out to
    /// readers until it has no remaining bytes. Chunks that arrive empty are
    /// skipped. An `Err` item is converted into an [`io::Error`] and returned
    /// from the read that pulled it. Once the stream yields `None`, reads
    /// complete with zero bytes, signalling end of input.
    ///
    /// In addition to [`AsyncRead`], this type implements [`AsyncBufRead`].
    /// This lets the unread part of the current chunk be borrowed directly
    /// instead of copied.
    ///
    /// Note that the reader does not record that the stream has ended. Every
    /// read issued after end of input polls the underlying stream again. The
    /// [`Stream`] contract does not specify what happens when a finished
    /// stream is polled. Reads after end of input are therefore only
    /// well-behaved if the underlying stream tolerates being polled after
    /// completion.
    ///
    /// # Example
    ///
    /// ```
    /// use bytes::Bytes;
    /// use tokio::io::{AsyncReadExt, Result};
    /// use tokio_util::io::StreamReader;
    /// # #[tokio::main]
    /// # async fn main() -> std::io::Result<()> {
    ///
    /// // Create a stream from an iterator.
    /// let stream = tokio_stream::iter(vec![
    ///     Result::Ok(Bytes::from_static(&[0, 1, 2, 3])),
    ///     Result::Ok(Bytes::from_static(&[4, 5, 6, 7])),
    ///     Result::Ok(Bytes::from_static(&[8, 9, 10, 11])),
    /// ]);
    ///
    /// // Convert it to an AsyncRead.
    /// let mut read = StreamReader::new(stream);
    ///
    /// // Read five bytes from the stream.
    /// let mut buf = [0; 5];
    /// read.read_exact(&mut buf).await?;
    /// assert_eq!(buf, [0, 1, 2, 3, 4]);
    ///
    /// // Read the rest of the current chunk.
    /// assert_eq!(read.read(&mut buf).await?, 3);
    /// assert_eq!(&buf[..3], [5, 6, 7]);
    ///
    /// // Read the next chunk.
    /// assert_eq!(read.read(&mut buf).await?, 4);
    /// assert_eq!(&buf[..4], [8, 9, 10, 11]);
    ///
    /// // We have now reached the end.
    /// assert_eq!(read.read(&mut buf).await?, 0);
    ///
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// [`AsyncRead`]: tokio::io::AsyncRead
    /// [`AsyncBufRead`]: tokio::io::AsyncBufRead
    /// [`Stream`]: futures_core::Stream
    /// [`ReaderStream`]: crate::io::ReaderStream
    /// [`io::Error`]: std::io::Error
    #[derive(Debug)]
    pub struct StreamReader<S, B> {
        // The wrapped stream of `Result<B, E>` chunks. It is structurally
        // pinned so it can be polled through `Pin<&mut Self>`.
        #[pin]
        inner: S,
        // The chunk currently being read from, if any. It starts as `None`
        // and is overwritten whenever a new chunk is pulled from `inner`.
        // A chunk whose bytes have all been consumed stays here until then.
        chunk: Option<B>,
    }
}
63 
64 impl<S, B, E> StreamReader<S, B>
65 where
66     S: Stream<Item = Result<B, E>>,
67     B: Buf,
68     E: Into<std::io::Error>,
69 {
70     /// Convert a stream of byte chunks into an [`AsyncRead`](tokio::io::AsyncRead).
71     ///
72     /// The item should be a [`Result`] with the ok variant being something that
73     /// implements the [`Buf`] trait (e.g. `Vec<u8>` or `Bytes`). The error
74     /// should be convertible into an [io error].
75     ///
76     /// [`Result`]: std::result::Result
77     /// [`Buf`]: bytes::Buf
78     /// [io error]: std::io::Error
new(stream: S) -> Self79     pub fn new(stream: S) -> Self {
80         Self {
81             inner: stream,
82             chunk: None,
83         }
84     }
85 
86     /// Do we have a chunk and is it non-empty?
has_chunk(self: Pin<&mut Self>) -> bool87     fn has_chunk(self: Pin<&mut Self>) -> bool {
88         if let Some(chunk) = self.project().chunk {
89             chunk.remaining() > 0
90         } else {
91             false
92         }
93     }
94 }
95 
96 impl<S, B> StreamReader<S, B> {
97     /// Gets a reference to the underlying stream.
98     ///
99     /// It is inadvisable to directly read from the underlying stream.
get_ref(&self) -> &S100     pub fn get_ref(&self) -> &S {
101         &self.inner
102     }
103 
104     /// Gets a mutable reference to the underlying stream.
105     ///
106     /// It is inadvisable to directly read from the underlying stream.
get_mut(&mut self) -> &mut S107     pub fn get_mut(&mut self) -> &mut S {
108         &mut self.inner
109     }
110 
111     /// Gets a pinned mutable reference to the underlying stream.
112     ///
113     /// It is inadvisable to directly read from the underlying stream.
get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S>114     pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
115         self.project().inner
116     }
117 
118     /// Consumes this `BufWriter`, returning the underlying stream.
119     ///
120     /// Note that any leftover data in the internal buffer is lost.
into_inner(self) -> S121     pub fn into_inner(self) -> S {
122         self.inner
123     }
124 }
125 
126 impl<S, B, E> AsyncRead for StreamReader<S, B>
127 where
128     S: Stream<Item = Result<B, E>>,
129     B: Buf,
130     E: Into<std::io::Error>,
131 {
poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>132     fn poll_read(
133         mut self: Pin<&mut Self>,
134         cx: &mut Context<'_>,
135         buf: &mut ReadBuf<'_>,
136     ) -> Poll<io::Result<()>> {
137         if buf.remaining() == 0 {
138             return Poll::Ready(Ok(()));
139         }
140 
141         let inner_buf = match self.as_mut().poll_fill_buf(cx) {
142             Poll::Ready(Ok(buf)) => buf,
143             Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
144             Poll::Pending => return Poll::Pending,
145         };
146         let len = std::cmp::min(inner_buf.len(), buf.remaining());
147         buf.put_slice(&inner_buf[..len]);
148 
149         self.consume(len);
150         Poll::Ready(Ok(()))
151     }
152 }
153 
154 impl<S, B, E> AsyncBufRead for StreamReader<S, B>
155 where
156     S: Stream<Item = Result<B, E>>,
157     B: Buf,
158     E: Into<std::io::Error>,
159 {
poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>160     fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
161         loop {
162             if self.as_mut().has_chunk() {
163                 // This unwrap is very sad, but it can't be avoided.
164                 let buf = self.project().chunk.as_ref().unwrap().chunk();
165                 return Poll::Ready(Ok(buf));
166             } else {
167                 match self.as_mut().project().inner.poll_next(cx) {
168                     Poll::Ready(Some(Ok(chunk))) => {
169                         // Go around the loop in case the chunk is empty.
170                         *self.as_mut().project().chunk = Some(chunk);
171                     }
172                     Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())),
173                     Poll::Ready(None) => return Poll::Ready(Ok(&[])),
174                     Poll::Pending => return Poll::Pending,
175                 }
176             }
177         }
178     }
consume(self: Pin<&mut Self>, amt: usize)179     fn consume(self: Pin<&mut Self>, amt: usize) {
180         if amt > 0 {
181             self.project()
182                 .chunk
183                 .as_mut()
184                 .expect("No chunk present")
185                 .advance(amt);
186         }
187     }
188 }
189