1 //! Module defining an Either type.
2 use std::{
3     future::Future,
4     io::SeekFrom,
5     pin::Pin,
6     task::{Context, Poll},
7 };
8 use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf, Result};
9 
/// Combines two different futures, streams, or I/O objects that share the same associated
/// types (such as `Output` or `Item`) into a single type.
///
/// This type implements common asynchronous traits such as [`Future`] and those in Tokio.
///
/// [`Future`]: std::future::Future
///
/// # Examples
///
/// The following code will not work:
///
/// ```compile_fail
/// # fn some_condition() -> bool { true }
/// # async fn some_async_function() -> u32 { 10 }
/// # async fn other_async_function() -> u32 { 20 }
/// #[tokio::main]
/// async fn main() {
///     let result = if some_condition() {
///         some_async_function()
///     } else {
///         other_async_function() // <- error: "`if` and `else` have incompatible types"
///     };
///
///     println!("Result is {}", result.await);
/// }
/// ```
///
/// This is because, although the output types of both futures are the same, the concrete
/// future types are different, and the compiler must be able to choose a single type for the
/// `result` variable.
///
/// When the output type is the same, we can wrap each future in `Either` to avoid the
/// issue:
///
/// ```
/// use tokio_util::either::Either;
/// # fn some_condition() -> bool { true }
/// # async fn some_async_function() -> u32 { 10 }
/// # async fn other_async_function() -> u32 { 20 }
///
/// #[tokio::main]
/// async fn main() {
///     let result = if some_condition() {
///         Either::Left(some_async_function())
///     } else {
///         Either::Right(other_async_function())
///     };
///
///     let value = result.await;
///     println!("Result is {}", value);
///     # assert_eq!(value, 10);
/// }
/// ```
#[allow(missing_docs)] // Doc-comments for variants in this particular case don't make much sense.
#[derive(Debug, Clone)]
pub enum Either<L, R> {
    Left(L),
    Right(R),
}
68 
/// A small helper macro that reduces boilerplate in the trait method implementations below.
///
/// It takes a method invocation on a pinned `self` (e.g. `self.poll(cx)`), projects the pin
/// onto whichever variant is currently held, and forwards the call, with the same arguments,
/// to that variant's value. The arguments must be plain identifiers, and at least one is
/// required.
///
/// `Pin` is referenced by its full path so call sites do not need to import it.
macro_rules! delegate_call {
    ($self:ident.$method:ident($($args:ident),+)) => {
        // SAFETY: this is structural pin projection. The pinned value is never moved: we only
        // obtain a `&mut` to the active variant's field and immediately re-pin it, so the
        // field stays pinned for as long as `Either` itself is. This relies on `Either`
        // neither implementing `Drop` nor manually implementing `Unpin` (this module does
        // neither), so the auto `Unpin` impl requires both `L: Unpin` and `R: Unpin`.
        unsafe {
            match $self.get_unchecked_mut() {
                Self::Left(l) => ::std::pin::Pin::new_unchecked(l).$method($($args),+),
                Self::Right(r) => ::std::pin::Pin::new_unchecked(r).$method($($args),+),
            }
        }
    };
}
82 
83 impl<L, R, O> Future for Either<L, R>
84 where
85     L: Future<Output = O>,
86     R: Future<Output = O>,
87 {
88     type Output = O;
89 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>90     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
91         delegate_call!(self.poll(cx))
92     }
93 }
94 
95 impl<L, R> AsyncRead for Either<L, R>
96 where
97     L: AsyncRead,
98     R: AsyncRead,
99 {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<Result<()>>100     fn poll_read(
101         self: Pin<&mut Self>,
102         cx: &mut Context<'_>,
103         buf: &mut ReadBuf<'_>,
104     ) -> Poll<Result<()>> {
105         delegate_call!(self.poll_read(cx, buf))
106     }
107 }
108 
109 impl<L, R> AsyncBufRead for Either<L, R>
110 where
111     L: AsyncBufRead,
112     R: AsyncBufRead,
113 {
poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>>114     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>> {
115         delegate_call!(self.poll_fill_buf(cx))
116     }
117 
consume(self: Pin<&mut Self>, amt: usize)118     fn consume(self: Pin<&mut Self>, amt: usize) {
119         delegate_call!(self.consume(amt))
120     }
121 }
122 
123 impl<L, R> AsyncSeek for Either<L, R>
124 where
125     L: AsyncSeek,
126     R: AsyncSeek,
127 {
start_seek(self: Pin<&mut Self>, position: SeekFrom) -> Result<()>128     fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> Result<()> {
129         delegate_call!(self.start_seek(position))
130     }
131 
poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64>>132     fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64>> {
133         delegate_call!(self.poll_complete(cx))
134     }
135 }
136 
137 impl<L, R> AsyncWrite for Either<L, R>
138 where
139     L: AsyncWrite,
140     R: AsyncWrite,
141 {
poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>>142     fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
143         delegate_call!(self.poll_write(cx, buf))
144     }
145 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>>146     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
147         delegate_call!(self.poll_flush(cx))
148     }
149 
poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>>150     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
151         delegate_call!(self.poll_shutdown(cx))
152     }
153 }
154 
155 impl<L, R> futures_core::stream::Stream for Either<L, R>
156 where
157     L: futures_core::stream::Stream,
158     R: futures_core::stream::Stream<Item = L::Item>,
159 {
160     type Item = L::Item;
161 
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>162     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
163         delegate_call!(self.poll_next(cx))
164     }
165 }
166 
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{repeat, AsyncReadExt, Repeat};
    use tokio_stream::{once, Once, StreamExt};

    /// A `Left` wrapping a single-item stream yields that item through `Either`.
    #[tokio::test]
    async fn either_is_stream() {
        let mut stream: Either<Once<u32>, Once<u32>> = Either::Left(once(1));
        let first = stream.next().await;
        assert_eq!(first, Some(1u32));
    }

    /// A `Right` wrapping an endless reader fills the whole buffer through `Either`.
    #[tokio::test]
    async fn either_is_async_read() {
        let mut reader: Either<Repeat, Repeat> = Either::Right(repeat(0b101));
        let mut buffer = [0u8; 3];
        reader.read_exact(&mut buffer).await.unwrap();
        assert_eq!(buffer, [0b101; 3]);
    }
}
189