1 use std::mem;
2 
3 use pin_project::pin_project;
4 use tokio::sync::{mpsc, watch};
5 
6 use super::{task, Future, Never, Pin, Poll};
7 
/// The value published on the drain `watch` channel while the channel is still open.
///
/// `Open` is the only variant: a "closed" value is never sent. Instead, watchers
/// observe the close when the `watch::Sender` is dropped.
#[derive(Copy, Clone)]
enum Action {
    Open,
}
15 
channel() -> (Signal, Watch)16 pub fn channel() -> (Signal, Watch) {
17     let (tx, rx) = watch::channel(Action::Open);
18     let (drained_tx, drained_rx) = mpsc::channel(1);
19     (
20         Signal {
21             drained_rx,
22             _tx: tx,
23         },
24         Watch { drained_tx, rx },
25     )
26 }
27 
28 pub struct Signal {
29     drained_rx: mpsc::Receiver<Never>,
30     _tx: watch::Sender<Action>,
31 }
32 
33 pub struct Draining {
34     drained_rx: mpsc::Receiver<Never>,
35 }
36 
37 #[derive(Clone)]
38 pub struct Watch {
39     drained_tx: mpsc::Sender<Never>,
40     rx: watch::Receiver<Action>,
41 }
42 
43 #[allow(missing_debug_implementations)]
44 #[pin_project]
45 pub struct Watching<F, FN> {
46     #[pin]
47     future: F,
48     state: State<FN>,
49     watch: Watch,
50 }
51 
/// Progress of a `Watching` future with respect to the drain signal.
enum State<D> {
    /// Still watching; carries the callback to run once draining is observed.
    Watch(D),
    /// Draining was observed and the callback consumed. This is also the placeholder
    /// left behind while the callback is temporarily taken out during a poll.
    Draining,
}
56 
57 impl Signal {
drain(self) -> Draining58     pub fn drain(self) -> Draining {
59         // Simply dropping `self.tx` will signal the watchers
60         Draining {
61             drained_rx: self.drained_rx,
62         }
63     }
64 }
65 
66 impl Future for Draining {
67     type Output = ();
68 
poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>69     fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
70         match ready!(self.drained_rx.poll_recv(cx)) {
71             Some(never) => match never {},
72             None => Poll::Ready(()),
73         }
74     }
75 }
76 
77 impl Watch {
watch<F, FN>(self, future: F, on_drain: FN) -> Watching<F, FN> where F: Future, FN: FnOnce(Pin<&mut F>),78     pub fn watch<F, FN>(self, future: F, on_drain: FN) -> Watching<F, FN>
79     where
80         F: Future,
81         FN: FnOnce(Pin<&mut F>),
82     {
83         Watching {
84             future,
85             state: State::Watch(on_drain),
86             watch: self,
87         }
88     }
89 }
90 
91 impl<F, FN> Future for Watching<F, FN>
92 where
93     F: Future,
94     FN: FnOnce(Pin<&mut F>),
95 {
96     type Output = F::Output;
97 
poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>98     fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
99         let mut me = self.project();
100         loop {
101             match mem::replace(me.state, State::Draining) {
102                 State::Watch(on_drain) => {
103                     match me.watch.rx.poll_recv_ref(cx) {
104                         Poll::Ready(None) => {
105                             // Drain has been triggered!
106                             on_drain(me.future.as_mut());
107                         }
108                         Poll::Ready(Some(_ /*State::Open*/)) | Poll::Pending => {
109                             *me.state = State::Watch(on_drain);
110                             return me.future.poll(cx);
111                         }
112                     }
113                 }
114                 State::Draining => return me.future.poll(cx),
115             }
116         }
117     }
118 }
119 
#[cfg(test)]
mod tests {
    use super::*;

    /// Inner future driven by hand. It counts its polls and stays `Pending` until
    /// `finished` is set.
    struct TestMe {
        draining: bool,
        finished: bool,
        poll_cnt: usize,
    }

    impl TestMe {
        /// An unfinished, not-yet-draining future that has never been polled.
        fn fresh() -> Self {
            TestMe {
                draining: false,
                finished: false,
                poll_cnt: 0,
            }
        }
    }

    impl Future for TestMe {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
            self.poll_cnt += 1;
            if !self.finished {
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    /// `on_drain` callback shared by the tests: flags the inner future as draining.
    fn mark_draining(mut fut: Pin<&mut TestMe>) {
        fut.draining = true;
    }

    #[test]
    fn watch() {
        let mut mock = tokio_test::task::spawn(());
        mock.enter(|cx, _| {
            let (signal, watch_handle) = channel();
            let mut watching = watch_handle.watch(TestMe::fresh(), mark_draining);

            assert_eq!(watching.future.poll_cnt, 0);

            // Before any drain, every poll simply forwards to the inner future.
            for expected_polls in 1..=2 {
                assert!(Pin::new(&mut watching).poll(cx).is_pending());
                assert_eq!(watching.future.poll_cnt, expected_polls);
            }

            let mut draining = signal.drain();
            // The drain is signaled, but the watcher only notices on its next poll.
            assert!(!watching.future.draining);
            assert_eq!(watching.future.poll_cnt, 2);

            // The next poll observes the drain and runs the callback.
            assert!(Pin::new(&mut watching).poll(cx).is_pending());
            assert_eq!(watching.future.poll_cnt, 3);
            assert!(watching.future.draining);

            // The drain can't complete while the watcher is still alive.
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            // Let the inner future finish, then drop the watcher.
            watching.future.finished = true;
            assert!(Pin::new(&mut watching).poll(cx).is_ready());
            assert_eq!(watching.future.poll_cnt, 4);
            drop(watching);

            assert!(Pin::new(&mut draining).poll(cx).is_ready());
        })
    }

    #[test]
    fn watch_clones() {
        let mut mock = tokio_test::task::spawn(());
        mock.enter(|cx, _| {
            let (signal, watch_handle) = channel();

            let first = watch_handle.clone().watch(TestMe::fresh(), mark_draining);
            let second = watch_handle.watch(TestMe::fresh(), mark_draining);

            let mut draining = signal.drain();

            // Two watchers are still outstanding.
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            // Dropping one of them isn't enough...
            drop(first);
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            // ...but once the last one is gone, the drain completes.
            drop(second);
            assert!(Pin::new(&mut draining).poll(cx).is_ready());
        });
    }
}
233