1 use core::pin::Pin;
2 use futures_core::future::TryFuture;
3 use futures_core::task::{Context, Poll};
4 
/// State machine behind the "run one future, then maybe another" combinators.
///
/// The chain first drives `Fut1`. Its result, together with the stored `Data`,
/// is handed to a continuation that either yields a final result or a second
/// future (`Fut2`) to drive next.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub(crate) enum TryChain<Fut1, Fut2, Data> {
    /// Driving the first future.
    ///
    /// The data is wrapped in an `Option` so it can be moved out through a
    /// `&mut` reference once the first future resolves.
    First(Fut1, Option<Data>),
    /// Driving the future produced by the continuation.
    Second(Fut2),
    /// Holds nothing: either finished, or momentarily between states.
    Empty,
}

// Both futures are polled in place while pinned, so the chain may only be
// `Unpin` when both of them are. `Data` is never pinned (it is only moved out
// with `Option::take`), so it imposes no requirement here.
impl<Fut1, Fut2, Data> Unpin for TryChain<Fut1, Fut2, Data>
where
    Fut1: Unpin,
    Fut2: Unpin,
{
}
14 
15 pub(crate) enum TryChainAction<Fut2>
16     where Fut2: TryFuture,
17 {
18     Future(Fut2),
19     Output(Result<Fut2::Ok, Fut2::Error>),
20 }
21 
22 impl<Fut1, Fut2, Data> TryChain<Fut1, Fut2, Data>
23     where Fut1: TryFuture,
24           Fut2: TryFuture,
25 {
new(fut1: Fut1, data: Data) -> TryChain<Fut1, Fut2, Data>26     pub(crate) fn new(fut1: Fut1, data: Data) -> TryChain<Fut1, Fut2, Data> {
27         TryChain::First(fut1, Some(data))
28     }
29 
is_terminated(&self) -> bool30     pub(crate) fn is_terminated(&self) -> bool {
31         match self {
32             TryChain::First(..) | TryChain::Second(_) => false,
33             TryChain::Empty => true,
34         }
35     }
36 
poll<F>( self: Pin<&mut Self>, cx: &mut Context<'_>, f: F, ) -> Poll<Result<Fut2::Ok, Fut2::Error>> where F: FnOnce(Result<Fut1::Ok, Fut1::Error>, Data) -> TryChainAction<Fut2>,37     pub(crate) fn poll<F>(
38         self: Pin<&mut Self>,
39         cx: &mut Context<'_>,
40         f: F,
41     ) -> Poll<Result<Fut2::Ok, Fut2::Error>>
42         where F: FnOnce(Result<Fut1::Ok, Fut1::Error>, Data) -> TryChainAction<Fut2>,
43     {
44         let mut f = Some(f);
45 
46         // Safe to call `get_unchecked_mut` because we won't move the futures.
47         let this = unsafe { self.get_unchecked_mut() };
48 
49         loop {
50             let (output, data) = match this {
51                 TryChain::First(fut1, data) => {
52                     // Poll the first future
53                     let output = ready!(unsafe { Pin::new_unchecked(fut1) }.try_poll(cx));
54                     (output, data.take().unwrap())
55                 }
56                 TryChain::Second(fut2) => {
57                     // Poll the second future
58                     return unsafe { Pin::new_unchecked(fut2) }
59                         .try_poll(cx)
60                         .map(|res| {
61                             *this = TryChain::Empty; // Drop fut2.
62                             res
63                         });
64                 }
65                 TryChain::Empty => {
66                     panic!("future must not be polled after it returned `Poll::Ready`");
67                 }
68             };
69 
70             *this = TryChain::Empty; // Drop fut1
71             let f = f.take().unwrap();
72             match f(output, data) {
73                 TryChainAction::Future(fut2) => *this = TryChain::Second(fut2),
74                 TryChainAction::Output(output) => return Poll::Ready(output),
75             }
76         }
77     }
78 }
79 
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::task::Poll;

    use futures_test::task::noop_context;

    use crate::future::ready;

    use super::{TryChain, TryChainAction};

    /// A chain whose first and second futures are both immediately ready must
    /// finish within a single `poll` call and then report itself terminated.
    #[test]
    fn try_chain_is_terminated() {
        let mut chain = TryChain::new(ready(Ok(1)), ());
        assert!(!chain.is_terminated());

        let mut cx = noop_context();
        let continuation = |first: Result<usize, ()>, ()| {
            assert!(first.is_ok());
            // An already-ready second future lets the chain complete right away.
            TryChainAction::Future(ready(Ok(2)))
        };
        let outcome = Pin::new(&mut chain).poll(&mut cx, continuation);

        assert_eq!(outcome, Poll::Ready::<Result<usize, ()>>(Ok(2)));
        assert!(chain.is_terminated());
    }
}
109