1 use BufStream; 2 3 use bytes::Buf; 4 use futures::Poll; 5 6 /// Limits the stream to a maximum amount of data. 7 #[derive(Debug)] 8 pub struct Limit<T> { 9 stream: T, 10 remaining: u64, 11 } 12 13 /// Errors returned from `Limit`. 14 #[derive(Debug)] 15 pub struct LimitError<T> { 16 /// When `None`, limit was reached 17 inner: Option<T>, 18 } 19 20 impl<T> Limit<T> { new(stream: T, amount: u64) -> Limit<T>21 pub(crate) fn new(stream: T, amount: u64) -> Limit<T> { 22 Limit { 23 stream, 24 remaining: amount, 25 } 26 } 27 } 28 29 impl<T> BufStream for Limit<T> 30 where 31 T: BufStream, 32 { 33 type Item = T::Item; 34 type Error = LimitError<T::Error>; 35 poll_buf(&mut self) -> Poll<Option<Self::Item>, Self::Error>36 fn poll_buf(&mut self) -> Poll<Option<Self::Item>, Self::Error> { 37 use futures::Async::Ready; 38 39 if self.stream.size_hint().lower() > self.remaining { 40 return Err(LimitError { inner: None }); 41 } 42 43 let res = self 44 .stream 45 .poll_buf() 46 .map_err(|err| LimitError { inner: Some(err) }); 47 48 match res { 49 Ok(Ready(Some(ref buf))) => { 50 if buf.remaining() as u64 > self.remaining { 51 self.remaining = 0; 52 return Err(LimitError { inner: None }); 53 } 54 55 self.remaining -= buf.remaining() as u64; 56 } 57 _ => {} 58 } 59 60 res 61 } 62 } 63 64 // ===== impl LimitError ===== 65 66 impl<T> LimitError<T> { 67 /// Returns `true` if the error was caused by polling the stream. is_stream_err(&self) -> bool68 pub fn is_stream_err(&self) -> bool { 69 self.inner.is_some() 70 } 71 72 /// Returns `true` if the stream reached its limit. is_limit_err(&self) -> bool73 pub fn is_limit_err(&self) -> bool { 74 self.inner.is_none() 75 } 76 } 77