1 use crate::stream_ext::Fuse; 2 use crate::Stream; 3 use tokio::time::{Instant, Sleep}; 4 5 use core::future::Future; 6 use core::pin::Pin; 7 use core::task::{Context, Poll}; 8 use pin_project_lite::pin_project; 9 use std::fmt; 10 use std::time::Duration; 11 12 pin_project! { 13 /// Stream returned by the [`timeout`](super::StreamExt::timeout) method. 14 #[must_use = "streams do nothing unless polled"] 15 #[derive(Debug)] 16 pub struct Timeout<S> { 17 #[pin] 18 stream: Fuse<S>, 19 #[pin] 20 deadline: Sleep, 21 duration: Duration, 22 poll_deadline: bool, 23 } 24 } 25 26 /// Error returned by `Timeout`. 27 #[derive(Debug, PartialEq)] 28 pub struct Elapsed(()); 29 30 impl<S: Stream> Timeout<S> { new(stream: S, duration: Duration) -> Self31 pub(super) fn new(stream: S, duration: Duration) -> Self { 32 let next = Instant::now() + duration; 33 let deadline = tokio::time::sleep_until(next); 34 35 Timeout { 36 stream: Fuse::new(stream), 37 deadline, 38 duration, 39 poll_deadline: true, 40 } 41 } 42 } 43 44 impl<S: Stream> Stream for Timeout<S> { 45 type Item = Result<S::Item, Elapsed>; 46 poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>47 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { 48 let me = self.project(); 49 50 match me.stream.poll_next(cx) { 51 Poll::Ready(v) => { 52 if v.is_some() { 53 let next = Instant::now() + *me.duration; 54 me.deadline.reset(next); 55 *me.poll_deadline = true; 56 } 57 return Poll::Ready(v.map(Ok)); 58 } 59 Poll::Pending => {} 60 }; 61 62 if *me.poll_deadline { 63 ready!(me.deadline.poll(cx)); 64 *me.poll_deadline = false; 65 return Poll::Ready(Some(Err(Elapsed::new()))); 66 } 67 68 Poll::Pending 69 } 70 size_hint(&self) -> (usize, Option<usize>)71 fn size_hint(&self) -> (usize, Option<usize>) { 72 let (lower, upper) = self.stream.size_hint(); 73 74 // The timeout stream may insert an error before and after each message 75 // from the underlying stream, but no more than one error between each 76 // message. Hence the upper bound is computed as 2x+1. 77 78 // Using a helper function to enable use of question mark operator. 79 fn twice_plus_one(value: Option<usize>) -> Option<usize> { 80 value?.checked_mul(2)?.checked_add(1) 81 } 82 83 (lower, twice_plus_one(upper)) 84 } 85 } 86 87 // ===== impl Elapsed ===== 88 89 impl Elapsed { new() -> Self90 pub(crate) fn new() -> Self { 91 Elapsed(()) 92 } 93 } 94 95 impl fmt::Display for Elapsed { fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result96 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 97 "deadline has elapsed".fmt(fmt) 98 } 99 } 100 101 impl std::error::Error for Elapsed {} 102 103 impl From<Elapsed> for std::io::Error { from(_err: Elapsed) -> std::io::Error104 fn from(_err: Elapsed) -> std::io::Error { 105 std::io::ErrorKind::TimedOut.into() 106 } 107 } 108