1 use std::mem;
2
3 use pin_project::pin_project;
4 use tokio::sync::{mpsc, watch};
5
6 use super::{task, Future, Never, Pin, Poll};
7
/// The only value ever stored in the drain watch channel.
///
/// Its presence means "still open". There is deliberately no `Closed`
/// variant: closure is communicated by dropping the `watch::Sender`, not by
/// sending a value.
#[derive(Copy, Clone)]
enum Action {
    Open,
}
15
channel() -> (Signal, Watch)16 pub fn channel() -> (Signal, Watch) {
17 let (tx, rx) = watch::channel(Action::Open);
18 let (drained_tx, drained_rx) = mpsc::channel(1);
19 (
20 Signal {
21 drained_rx,
22 _tx: tx,
23 },
24 Watch { drained_tx, rx },
25 )
26 }
27
/// The controlling half of a drain pair, created by `channel`.
///
/// Consume it with `Signal::drain` to notify every `Watch`.
pub struct Signal {
    /// Never receives a value (`Never` is uninhabited); it is polled by
    /// `Draining` until it reports that all `Watch` senders are gone.
    drained_rx: mpsc::Receiver<Never>,
    /// Never sent on; dropping it closes the watch channel, which is what
    /// `Watching` futures observe as the drain trigger.
    _tx: watch::Sender<Action>,
}
32
/// Future returned by `Signal::drain`; resolves once every `Watch` handle
/// (and thus every `mpsc::Sender<Never>`) has been dropped.
pub struct Draining {
    /// Yields `None` once the last sender held by a `Watch` is dropped.
    drained_rx: mpsc::Receiver<Never>,
}
36
/// Cloneable handle used to wrap futures via `Watch::watch`.
///
/// Every clone holds its own `mpsc::Sender`, so `Draining` cannot complete
/// until all clones (including the ones moved into `Watching` futures) are
/// dropped.
#[derive(Clone)]
pub struct Watch {
    /// Held only for its lifetime; nothing is ever sent (`Never` is uninhabited).
    drained_tx: mpsc::Sender<Never>,
    /// Reports closure once the `Signal`'s watch sender has been dropped.
    rx: watch::Receiver<Action>,
}
42
/// Future returned by `Watch::watch`: drives `future` and, the first time it
/// notices that a drain was signaled, hands the pinned future to `on_drain`.
// `FN` is typically a closure, and closures do not implement `Debug`.
#[allow(missing_debug_implementations)]
#[pin_project]
pub struct Watching<F, FN> {
    // The wrapped future. Fields drop in declaration order, so this is
    // dropped before `watch`; the drained sender therefore outlives it and
    // `Draining` cannot resolve while this future still exists.
    #[pin]
    future: F,
    // `Watch(on_drain)` until the drain is noticed, then `Draining`.
    state: State<FN>,
    // Supplies the watch receiver and keeps a drained sender alive.
    watch: Watch,
}
51
/// Progress of a `Watching` future with respect to the drain signal.
enum State<OnDrain> {
    /// No drain observed yet; holds the callback to run when one is.
    Watch(OnDrain),
    /// The drain was observed and the callback has been consumed.
    Draining,
}
56
57 impl Signal {
drain(self) -> Draining58 pub fn drain(self) -> Draining {
59 // Simply dropping `self.tx` will signal the watchers
60 Draining {
61 drained_rx: self.drained_rx,
62 }
63 }
64 }
65
66 impl Future for Draining {
67 type Output = ();
68
poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>69 fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
70 match ready!(self.drained_rx.poll_recv(cx)) {
71 Some(never) => match never {},
72 None => Poll::Ready(()),
73 }
74 }
75 }
76
77 impl Watch {
watch<F, FN>(self, future: F, on_drain: FN) -> Watching<F, FN> where F: Future, FN: FnOnce(Pin<&mut F>),78 pub fn watch<F, FN>(self, future: F, on_drain: FN) -> Watching<F, FN>
79 where
80 F: Future,
81 FN: FnOnce(Pin<&mut F>),
82 {
83 Watching {
84 future,
85 state: State::Watch(on_drain),
86 watch: self,
87 }
88 }
89 }
90
91 impl<F, FN> Future for Watching<F, FN>
92 where
93 F: Future,
94 FN: FnOnce(Pin<&mut F>),
95 {
96 type Output = F::Output;
97
poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>98 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
99 let mut me = self.project();
100 loop {
101 match mem::replace(me.state, State::Draining) {
102 State::Watch(on_drain) => {
103 match me.watch.rx.poll_recv_ref(cx) {
104 Poll::Ready(None) => {
105 // Drain has been triggered!
106 on_drain(me.future.as_mut());
107 }
108 Poll::Ready(Some(_ /*State::Open*/)) | Poll::Pending => {
109 *me.state = State::Watch(on_drain);
110 return me.future.poll(cx);
111 }
112 }
113 }
114 State::Draining => return me.future.poll(cx),
115 }
116 }
117 }
118 }
119
#[cfg(test)]
mod tests {
    use super::*;

    /// A hand-driven future. It counts its polls, completes only once
    /// `finished` is set, and records in `draining` whether the drain
    /// callback ran.
    struct TestMe {
        draining: bool,
        finished: bool,
        poll_cnt: usize,
    }

    impl TestMe {
        /// A fresh future: unfinished, not draining, never polled.
        fn new() -> Self {
            TestMe {
                draining: false,
                finished: false,
                poll_cnt: 0,
            }
        }
    }

    impl Future for TestMe {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll<Self::Output> {
            self.poll_cnt += 1;
            if !self.finished {
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    /// The `on_drain` callback shared by every test: just flag the future.
    fn mark_draining(mut fut: Pin<&mut TestMe>) {
        fut.draining = true;
    }

    #[test]
    fn watch() {
        let mut mock = tokio_test::task::spawn(());
        mock.enter(|cx, _| {
            let (tx, rx) = channel();
            let mut watch = rx.watch(TestMe::new(), mark_draining);
            assert_eq!(watch.future.poll_cnt, 0);

            // Without a drain, each poll simply forwards to the inner future.
            for expected in 1..=2 {
                assert!(Pin::new(&mut watch).poll(cx).is_pending());
                assert_eq!(watch.future.poll_cnt, expected);
            }

            let mut draining = tx.drain();
            // The drain is signaled but only noticed on the watcher's next poll.
            assert!(!watch.future.draining);
            assert_eq!(watch.future.poll_cnt, 2);

            assert!(Pin::new(&mut watch).poll(cx).is_pending());
            assert_eq!(watch.future.poll_cnt, 3);
            assert!(watch.future.draining);

            // Draining cannot finish while a watcher is still alive.
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            // Let the watched future complete, then release the watcher.
            watch.future.finished = true;
            assert!(Pin::new(&mut watch).poll(cx).is_ready());
            assert_eq!(watch.future.poll_cnt, 4);
            drop(watch);

            assert!(Pin::new(&mut draining).poll(cx).is_ready());
        })
    }

    #[test]
    fn watch_clones() {
        let mut mock = tokio_test::task::spawn(());
        mock.enter(|cx, _| {
            let (tx, rx) = channel();
            let watch1 = rx.clone().watch(TestMe::new(), mark_draining);
            let watch2 = rx.watch(TestMe::new(), mark_draining);

            let mut draining = tx.drain();

            // Two watchers are still outstanding.
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            // Dropping one of them still leaves the other.
            drop(watch1);
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            // Once the last watcher is gone, draining completes.
            drop(watch2);
            assert!(Pin::new(&mut draining).poll(cx).is_ready());
        });
    }
}
233