1 //! Module defining an Either type. 2 use std::{ 3 future::Future, 4 io::SeekFrom, 5 pin::Pin, 6 task::{Context, Poll}, 7 }; 8 use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf, Result}; 9 10 /// Combines two different futures, streams, or sinks having the same associated types into a single type. 11 /// 12 /// This type implements common asynchronous traits such as [`Future`] and those in Tokio. 13 /// 14 /// [`Future`]: std::future::Future 15 /// 16 /// # Example 17 /// 18 /// The following code will not work: 19 /// 20 /// ```compile_fail 21 /// # fn some_condition() -> bool { true } 22 /// # async fn some_async_function() -> u32 { 10 } 23 /// # async fn other_async_function() -> u32 { 20 } 24 /// #[tokio::main] 25 /// async fn main() { 26 /// let result = if some_condition() { 27 /// some_async_function() 28 /// } else { 29 /// other_async_function() // <- Will print: "`if` and `else` have incompatible types" 30 /// }; 31 /// 32 /// println!("Result is {}", result.await); 33 /// } 34 /// ``` 35 /// 36 // This is because although the output types for both futures is the same, the exact future 37 // types are different, but the compiler must be able to choose a single type for the 38 // `result` variable. 39 /// 40 /// When the output type is the same, we can wrap each future in `Either` to avoid the 41 /// issue: 42 /// 43 /// ``` 44 /// use tokio_util::either::Either; 45 /// # fn some_condition() -> bool { true } 46 /// # async fn some_async_function() -> u32 { 10 } 47 /// # async fn other_async_function() -> u32 { 20 } 48 /// 49 /// #[tokio::main] 50 /// async fn main() { 51 /// let result = if some_condition() { 52 /// Either::Left(some_async_function()) 53 /// } else { 54 /// Either::Right(other_async_function()) 55 /// }; 56 /// 57 /// let value = result.await; 58 /// println!("Result is {}", value); 59 /// # assert_eq!(value, 10); 60 /// } 61 /// ``` 62 #[allow(missing_docs)] // Doc-comments for variants in this particular case don't make much sense. 63 #[derive(Debug, Clone)] 64 pub enum Either<L, R> { 65 Left(L), 66 Right(R), 67 } 68 69 /// A small helper macro which reduces amount of boilerplate in the actual trait method implementation. 70 /// It takes an invokation of method as an argument (e.g. `self.poll(cx)`), and redirects it to either 71 /// enum variant held in `self`. 72 macro_rules! delegate_call { 73 ($self:ident.$method:ident($($args:ident),+)) => { 74 unsafe { 75 match $self.get_unchecked_mut() { 76 Self::Left(l) => Pin::new_unchecked(l).$method($($args),+), 77 Self::Right(r) => Pin::new_unchecked(r).$method($($args),+), 78 } 79 } 80 } 81 } 82 83 impl<L, R, O> Future for Either<L, R> 84 where 85 L: Future<Output = O>, 86 R: Future<Output = O>, 87 { 88 type Output = O; 89 poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>90 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 91 delegate_call!(self.poll(cx)) 92 } 93 } 94 95 impl<L, R> AsyncRead for Either<L, R> 96 where 97 L: AsyncRead, 98 R: AsyncRead, 99 { poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<Result<()>>100 fn poll_read( 101 self: Pin<&mut Self>, 102 cx: &mut Context<'_>, 103 buf: &mut ReadBuf<'_>, 104 ) -> Poll<Result<()>> { 105 delegate_call!(self.poll_read(cx, buf)) 106 } 107 } 108 109 impl<L, R> AsyncBufRead for Either<L, R> 110 where 111 L: AsyncBufRead, 112 R: AsyncBufRead, 113 { poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>>114 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>> { 115 delegate_call!(self.poll_fill_buf(cx)) 116 } 117 consume(self: Pin<&mut Self>, amt: usize)118 fn consume(self: Pin<&mut Self>, amt: usize) { 119 delegate_call!(self.consume(amt)) 120 } 121 } 122 123 impl<L, R> AsyncSeek for Either<L, R> 124 where 125 L: AsyncSeek, 126 R: AsyncSeek, 127 { start_seek(self: Pin<&mut Self>, position: SeekFrom) -> Result<()>128 fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> Result<()> { 129 delegate_call!(self.start_seek(position)) 130 } 131 poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64>>132 fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64>> { 133 delegate_call!(self.poll_complete(cx)) 134 } 135 } 136 137 impl<L, R> AsyncWrite for Either<L, R> 138 where 139 L: AsyncWrite, 140 R: AsyncWrite, 141 { poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>>142 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> { 143 delegate_call!(self.poll_write(cx, buf)) 144 } 145 poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>>146 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> { 147 delegate_call!(self.poll_flush(cx)) 148 } 149 poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>>150 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> { 151 delegate_call!(self.poll_shutdown(cx)) 152 } 153 } 154 155 impl<L, R> futures_core::stream::Stream for Either<L, R> 156 where 157 L: futures_core::stream::Stream, 158 R: futures_core::stream::Stream<Item = L::Item>, 159 { 160 type Item = L::Item; 161 poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>162 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { 163 delegate_call!(self.poll_next(cx)) 164 } 165 } 166 167 #[cfg(test)] 168 mod tests { 169 use super::*; 170 use tokio::io::{repeat, AsyncReadExt, Repeat}; 171 use tokio_stream::{once, Once, StreamExt}; 172 173 #[tokio::test] either_is_stream()174 async fn either_is_stream() { 175 let mut either: Either<Once<u32>, Once<u32>> = Either::Left(once(1)); 176 177 assert_eq!(Some(1u32), either.next().await); 178 } 179 180 #[tokio::test] either_is_async_read()181 async fn either_is_async_read() { 182 let mut buffer = [0; 3]; 183 let mut either: Either<Repeat, Repeat> = Either::Right(repeat(0b101)); 184 185 either.read_exact(&mut buffer).await.unwrap(); 186 assert_eq!(buffer, [0b101, 0b101, 0b101]); 187 } 188 } 189