1 use crate::stream_ext::Fuse;
2 use crate::Stream;
3 use tokio::time::{Instant, Sleep};
4 
5 use core::future::Future;
6 use core::pin::Pin;
7 use core::task::{Context, Poll};
8 use pin_project_lite::pin_project;
9 use std::fmt;
10 use std::time::Duration;
11 
pin_project! {
    /// Stream returned by the [`timeout`](super::StreamExt::timeout) method.
    ///
    /// Yields `Ok(item)` for every item produced by the wrapped stream, and
    /// `Err(Elapsed)` when the deadline passes while waiting for the next item.
    #[must_use = "streams do nothing unless polled"]
    #[derive(Debug)]
    pub struct Timeout<S> {
        // The wrapped stream, held inside a `Fuse` adapter.
        #[pin]
        stream: Fuse<S>,
        // Timer for the current deadline; reset to `duration` from now each
        // time the wrapped stream yields an item.
        #[pin]
        deadline: Sleep,
        // Time allowed before the first item and between consecutive items.
        duration: Duration,
        // Whether `deadline` is currently armed. Cleared after an `Elapsed`
        // error is yielded and set again when the next item arrives, so at
        // most one error is reported per gap between items.
        poll_deadline: bool,
    }
}
25 
/// Error yielded by a [`Timeout`] stream when its deadline passes before the
/// wrapped stream produces its next item.
///
/// The type carries no data; every `Elapsed` value is equal to every other,
/// so it implements `Eq` as well as `PartialEq`.
#[derive(Debug, PartialEq, Eq)]
pub struct Elapsed(());
29 
30 impl<S: Stream> Timeout<S> {
new(stream: S, duration: Duration) -> Self31     pub(super) fn new(stream: S, duration: Duration) -> Self {
32         let next = Instant::now() + duration;
33         let deadline = tokio::time::sleep_until(next);
34 
35         Timeout {
36             stream: Fuse::new(stream),
37             deadline,
38             duration,
39             poll_deadline: true,
40         }
41     }
42 }
43 
44 impl<S: Stream> Stream for Timeout<S> {
45     type Item = Result<S::Item, Elapsed>;
46 
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>47     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
48         let me = self.project();
49 
50         match me.stream.poll_next(cx) {
51             Poll::Ready(v) => {
52                 if v.is_some() {
53                     let next = Instant::now() + *me.duration;
54                     me.deadline.reset(next);
55                     *me.poll_deadline = true;
56                 }
57                 return Poll::Ready(v.map(Ok));
58             }
59             Poll::Pending => {}
60         };
61 
62         if *me.poll_deadline {
63             ready!(me.deadline.poll(cx));
64             *me.poll_deadline = false;
65             return Poll::Ready(Some(Err(Elapsed::new())));
66         }
67 
68         Poll::Pending
69     }
70 
size_hint(&self) -> (usize, Option<usize>)71     fn size_hint(&self) -> (usize, Option<usize>) {
72         let (lower, upper) = self.stream.size_hint();
73 
74         // The timeout stream may insert an error before and after each message
75         // from the underlying stream, but no more than one error between each
76         // message. Hence the upper bound is computed as 2x+1.
77 
78         // Using a helper function to enable use of question mark operator.
79         fn twice_plus_one(value: Option<usize>) -> Option<usize> {
80             value?.checked_mul(2)?.checked_add(1)
81         }
82 
83         (lower, twice_plus_one(upper))
84     }
85 }
86 
87 // ===== impl Elapsed =====
88 
89 impl Elapsed {
new() -> Self90     pub(crate) fn new() -> Self {
91         Elapsed(())
92     }
93 }
94 
95 impl fmt::Display for Elapsed {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result96     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
97         "deadline has elapsed".fmt(fmt)
98     }
99 }
100 
101 impl std::error::Error for Elapsed {}
102 
103 impl From<Elapsed> for std::io::Error {
from(_err: Elapsed) -> std::io::Error104     fn from(_err: Elapsed) -> std::io::Error {
105         std::io::ErrorKind::TimedOut.into()
106     }
107 }
108