1 use crate::io::{AsyncBufRead, AsyncRead};
2 
3 use pin_project_lite::pin_project;
4 use std::fmt;
5 use std::io;
6 use std::pin::Pin;
7 use std::task::{Context, Poll};
8 
9 pin_project! {
10     /// Stream for the [`chain`](super::AsyncReadExt::chain) method.
11     #[must_use = "streams do nothing unless polled"]
12     #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
13     pub struct Chain<T, U> {
14         #[pin]
15         first: T,
16         #[pin]
17         second: U,
18         done_first: bool,
19     }
20 }
21 
chain<T, U>(first: T, second: U) -> Chain<T, U> where T: AsyncRead, U: AsyncRead,22 pub(super) fn chain<T, U>(first: T, second: U) -> Chain<T, U>
23 where
24     T: AsyncRead,
25     U: AsyncRead,
26 {
27     Chain {
28         first,
29         second,
30         done_first: false,
31     }
32 }
33 
34 impl<T, U> Chain<T, U>
35 where
36     T: AsyncRead,
37     U: AsyncRead,
38 {
39     /// Gets references to the underlying readers in this `Chain`.
get_ref(&self) -> (&T, &U)40     pub fn get_ref(&self) -> (&T, &U) {
41         (&self.first, &self.second)
42     }
43 
44     /// Gets mutable references to the underlying readers in this `Chain`.
45     ///
46     /// Care should be taken to avoid modifying the internal I/O state of the
47     /// underlying readers as doing so may corrupt the internal state of this
48     /// `Chain`.
get_mut(&mut self) -> (&mut T, &mut U)49     pub fn get_mut(&mut self) -> (&mut T, &mut U) {
50         (&mut self.first, &mut self.second)
51     }
52 
53     /// Gets pinned mutable references to the underlying readers in this `Chain`.
54     ///
55     /// Care should be taken to avoid modifying the internal I/O state of the
56     /// underlying readers as doing so may corrupt the internal state of this
57     /// `Chain`.
get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>)58     pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>) {
59         let me = self.project();
60         (me.first, me.second)
61     }
62 
63     /// Consumes the `Chain`, returning the wrapped readers.
into_inner(self) -> (T, U)64     pub fn into_inner(self) -> (T, U) {
65         (self.first, self.second)
66     }
67 }
68 
69 impl<T, U> fmt::Debug for Chain<T, U>
70 where
71     T: fmt::Debug,
72     U: fmt::Debug,
73 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result74     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75         f.debug_struct("Chain")
76             .field("t", &self.first)
77             .field("u", &self.second)
78             .finish()
79     }
80 }
81 
82 impl<T, U> AsyncRead for Chain<T, U>
83 where
84     T: AsyncRead,
85     U: AsyncRead,
86 {
prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool87     unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
88         if self.first.prepare_uninitialized_buffer(buf) {
89             return true;
90         }
91         if self.second.prepare_uninitialized_buffer(buf) {
92             return true;
93         }
94         false
95     }
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<io::Result<usize>>96     fn poll_read(
97         self: Pin<&mut Self>,
98         cx: &mut Context<'_>,
99         buf: &mut [u8],
100     ) -> Poll<io::Result<usize>> {
101         let me = self.project();
102 
103         if !*me.done_first {
104             match ready!(me.first.poll_read(cx, buf)?) {
105                 0 if !buf.is_empty() => *me.done_first = true,
106                 n => return Poll::Ready(Ok(n)),
107             }
108         }
109         me.second.poll_read(cx, buf)
110     }
111 }
112 
113 impl<T, U> AsyncBufRead for Chain<T, U>
114 where
115     T: AsyncBufRead,
116     U: AsyncBufRead,
117 {
poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>118     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
119         let me = self.project();
120 
121         if !*me.done_first {
122             match ready!(me.first.poll_fill_buf(cx)?) {
123                 buf if buf.is_empty() => {
124                     *me.done_first = true;
125                 }
126                 buf => return Poll::Ready(Ok(buf)),
127             }
128         }
129         me.second.poll_fill_buf(cx)
130     }
131 
consume(self: Pin<&mut Self>, amt: usize)132     fn consume(self: Pin<&mut Self>, amt: usize) {
133         let me = self.project();
134         if !*me.done_first {
135             me.first.consume(amt)
136         } else {
137             me.second.consume(amt)
138         }
139     }
140 }
141 
#[cfg(test)]
mod tests {
    use super::Chain;

    /// Passes `Chain<(), ()>` to the crate's `is_unpin` helper.
    #[test]
    fn assert_unpin() {
        type UnitChain = Chain<(), ()>;
        crate::is_unpin::<UnitChain>();
    }
}
151