1 use futures_core::task::{Context, Poll};
2 #[cfg(feature = "read-initializer")]
3 use futures_io::Initializer;
4 use futures_io::{AsyncRead, AsyncBufRead};
5 use pin_project::{pin_project, project};
6 use std::{cmp, io};
7 use std::pin::Pin;
8 
/// Reader for the [`take`](super::AsyncReadExt::take) method.
///
/// Wraps an inner reader and yields at most a fixed number of bytes from it,
/// after which it reports EOF (a read of zero bytes).
#[pin_project]
#[derive(Debug)]
#[must_use = "readers do nothing unless you `.await` or poll them"]
pub struct Take<R> {
    // Structurally pinned: the poll methods project this field to `Pin<&mut R>`.
    #[pin]
    inner: R,
    // Remaining number of bytes that may still be read before EOF is reported.
    // Add '_' to avoid conflicts with `limit` method.
    limit_: u64,
}
19 
20 impl<R: AsyncRead> Take<R> {
new(inner: R, limit: u64) -> Self21     pub(super) fn new(inner: R, limit: u64) -> Self {
22         Self { inner, limit_: limit }
23     }
24 
25     /// Returns the remaining number of bytes that can be
26     /// read before this instance will return EOF.
27     ///
28     /// # Note
29     ///
30     /// This instance may reach `EOF` after reading fewer bytes than indicated by
31     /// this method if the underlying [`AsyncRead`] instance reaches EOF.
32     ///
33     /// # Examples
34     ///
35     /// ```
36     /// # futures::executor::block_on(async {
37     /// use futures::io::{AsyncReadExt, Cursor};
38     ///
39     /// let reader = Cursor::new(&b"12345678"[..]);
40     /// let mut buffer = [0; 2];
41     ///
42     /// let mut take = reader.take(4);
43     /// let n = take.read(&mut buffer).await?;
44     ///
45     /// assert_eq!(take.limit(), 2);
46     /// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
47     /// ```
limit(&self) -> u6448     pub fn limit(&self) -> u64 {
49         self.limit_
50     }
51 
52     /// Sets the number of bytes that can be read before this instance will
53     /// return EOF. This is the same as constructing a new `Take` instance, so
54     /// the amount of bytes read and the previous limit value don't matter when
55     /// calling this method.
56     ///
57     /// # Examples
58     ///
59     /// ```
60     /// # futures::executor::block_on(async {
61     /// use futures::io::{AsyncReadExt, Cursor};
62     ///
63     /// let reader = Cursor::new(&b"12345678"[..]);
64     /// let mut buffer = [0; 4];
65     ///
66     /// let mut take = reader.take(4);
67     /// let n = take.read(&mut buffer).await?;
68     ///
69     /// assert_eq!(n, 4);
70     /// assert_eq!(take.limit(), 0);
71     ///
72     /// take.set_limit(10);
73     /// let n = take.read(&mut buffer).await?;
74     /// assert_eq!(n, 4);
75     ///
76     /// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
77     /// ```
set_limit(&mut self, limit: u64)78     pub fn set_limit(&mut self, limit: u64) {
79         self.limit_ = limit
80     }
81 
82     delegate_access_inner!(inner, R, ());
83 }
84 
85 impl<R: AsyncRead> AsyncRead for Take<R> {
86     #[project]
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<Result<usize, io::Error>>87     fn poll_read(
88         self: Pin<&mut Self>,
89         cx: &mut Context<'_>,
90         buf: &mut [u8],
91     ) -> Poll<Result<usize, io::Error>> {
92         #[project]
93         let Take { inner, limit_ } = self.project();
94 
95         if *limit_ == 0 {
96             return Poll::Ready(Ok(0));
97         }
98 
99         let max = std::cmp::min(buf.len() as u64, *limit_) as usize;
100         let n = ready!(inner.poll_read(cx, &mut buf[..max]))?;
101         *limit_ -= n as u64;
102         Poll::Ready(Ok(n))
103     }
104 
105     #[cfg(feature = "read-initializer")]
initializer(&self) -> Initializer106     unsafe fn initializer(&self) -> Initializer {
107         self.inner.initializer()
108     }
109 }
110 
111 impl<R: AsyncBufRead> AsyncBufRead for Take<R> {
112     #[project]
poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>113     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
114         #[project]
115         let Take { inner, limit_ } = self.project();
116 
117         // Don't call into inner reader at all at EOF because it may still block
118         if *limit_ == 0 {
119             return Poll::Ready(Ok(&[]));
120         }
121 
122         let buf = ready!(inner.poll_fill_buf(cx)?);
123         let cap = cmp::min(buf.len() as u64, *limit_) as usize;
124         Poll::Ready(Ok(&buf[..cap]))
125     }
126 
127     #[project]
consume(self: Pin<&mut Self>, amt: usize)128     fn consume(self: Pin<&mut Self>, amt: usize) {
129         #[project]
130         let Take { inner, limit_ } = self.project();
131 
132         // Don't let callers reset the limit by passing an overlarge value
133         let amt = cmp::min(amt as u64, *limit_) as usize;
134         *limit_ -= amt as u64;
135         inner.consume(amt);
136     }
137 }
138