use std::error::Error;
use std::fmt;

use BufStream;

use bytes::Buf;
use futures::Poll;
5 
/// Limits the stream to a maximum amount of data.
///
/// Pairs a wrapped stream with the number of bytes it is still allowed to
/// yield.
#[derive(Debug)]
pub struct Limit<T> {
    /// The wrapped stream whose output is being capped.
    stream: T,
    /// Bytes that may still be yielded before the limit is hit.
    remaining: u64,
}
12 
/// Errors returned from `Limit`.
///
/// The error either wraps a failure produced by the underlying stream or
/// signals that the configured limit was hit.
#[derive(Debug)]
pub struct LimitError<T> {
    /// `Some(err)` carries an error from the wrapped stream; `None` means the
    /// limit was reached.
    inner: Option<T>,
}
19 
20 impl<T> Limit<T> {
new(stream: T, amount: u64) -> Limit<T>21     pub(crate) fn new(stream: T, amount: u64) -> Limit<T> {
22         Limit {
23             stream,
24             remaining: amount,
25         }
26     }
27 }
28 
29 impl<T> BufStream for Limit<T>
30 where
31     T: BufStream,
32 {
33     type Item = T::Item;
34     type Error = LimitError<T::Error>;
35 
poll_buf(&mut self) -> Poll<Option<Self::Item>, Self::Error>36     fn poll_buf(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
37         use futures::Async::Ready;
38 
39         if self.stream.size_hint().lower() > self.remaining {
40             return Err(LimitError { inner: None });
41         }
42 
43         let res = self
44             .stream
45             .poll_buf()
46             .map_err(|err| LimitError { inner: Some(err) });
47 
48         match res {
49             Ok(Ready(Some(ref buf))) => {
50                 if buf.remaining() as u64 > self.remaining {
51                     self.remaining = 0;
52                     return Err(LimitError { inner: None });
53                 }
54 
55                 self.remaining -= buf.remaining() as u64;
56             }
57             _ => {}
58         }
59 
60         res
61     }
62 }
63 
64 // ===== impl LimitError =====
65 
66 impl<T> LimitError<T> {
67     /// Returns `true` if the error was caused by polling the stream.
is_stream_err(&self) -> bool68     pub fn is_stream_err(&self) -> bool {
69         self.inner.is_some()
70     }
71 
72     /// Returns `true` if the stream reached its limit.
is_limit_err(&self) -> bool73     pub fn is_limit_err(&self) -> bool {
74         self.inner.is_none()
75     }
76 }
77