1 use crate::io::{AsyncBufRead, AsyncRead, ReadBuf};
2 
3 use pin_project_lite::pin_project;
4 use std::fmt;
5 use std::io;
6 use std::pin::Pin;
7 use std::task::{Context, Poll};
8 
pin_project! {
    /// Stream for the [`chain`](super::AsyncReadExt::chain) method.
    ///
    /// Reads are served from `first` until it is considered exhausted, and
    /// from `second` after that.
    #[must_use = "streams do nothing unless polled"]
    #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
    pub struct Chain<T, U> {
        // Reader consumed first; structurally pinned (projected as `Pin<&mut T>`).
        #[pin]
        first: T,
        // Reader consumed once `first` is exhausted; structurally pinned.
        #[pin]
        second: U,
        // Starts `false`; flipped to `true` once `first` is observed at EOF and
        // never reset, so all later reads, fills and consumes go to `second`.
        done_first: bool,
    }
}
21 
chain<T, U>(first: T, second: U) -> Chain<T, U> where T: AsyncRead, U: AsyncRead,22 pub(super) fn chain<T, U>(first: T, second: U) -> Chain<T, U>
23 where
24     T: AsyncRead,
25     U: AsyncRead,
26 {
27     Chain {
28         first,
29         second,
30         done_first: false,
31     }
32 }
33 
34 impl<T, U> Chain<T, U>
35 where
36     T: AsyncRead,
37     U: AsyncRead,
38 {
39     /// Gets references to the underlying readers in this `Chain`.
get_ref(&self) -> (&T, &U)40     pub fn get_ref(&self) -> (&T, &U) {
41         (&self.first, &self.second)
42     }
43 
44     /// Gets mutable references to the underlying readers in this `Chain`.
45     ///
46     /// Care should be taken to avoid modifying the internal I/O state of the
47     /// underlying readers as doing so may corrupt the internal state of this
48     /// `Chain`.
get_mut(&mut self) -> (&mut T, &mut U)49     pub fn get_mut(&mut self) -> (&mut T, &mut U) {
50         (&mut self.first, &mut self.second)
51     }
52 
53     /// Gets pinned mutable references to the underlying readers in this `Chain`.
54     ///
55     /// Care should be taken to avoid modifying the internal I/O state of the
56     /// underlying readers as doing so may corrupt the internal state of this
57     /// `Chain`.
get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>)58     pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>) {
59         let me = self.project();
60         (me.first, me.second)
61     }
62 
63     /// Consumes the `Chain`, returning the wrapped readers.
into_inner(self) -> (T, U)64     pub fn into_inner(self) -> (T, U) {
65         (self.first, self.second)
66     }
67 }
68 
69 impl<T, U> fmt::Debug for Chain<T, U>
70 where
71     T: fmt::Debug,
72     U: fmt::Debug,
73 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result74     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75         f.debug_struct("Chain")
76             .field("t", &self.first)
77             .field("u", &self.second)
78             .finish()
79     }
80 }
81 
82 impl<T, U> AsyncRead for Chain<T, U>
83 where
84     T: AsyncRead,
85     U: AsyncRead,
86 {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>87     fn poll_read(
88         self: Pin<&mut Self>,
89         cx: &mut Context<'_>,
90         buf: &mut ReadBuf<'_>,
91     ) -> Poll<io::Result<()>> {
92         let me = self.project();
93 
94         if !*me.done_first {
95             let rem = buf.remaining();
96             ready!(me.first.poll_read(cx, buf))?;
97             if buf.remaining() == rem {
98                 *me.done_first = true;
99             } else {
100                 return Poll::Ready(Ok(()));
101             }
102         }
103         me.second.poll_read(cx, buf)
104     }
105 }
106 
107 impl<T, U> AsyncBufRead for Chain<T, U>
108 where
109     T: AsyncBufRead,
110     U: AsyncBufRead,
111 {
poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>112     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
113         let me = self.project();
114 
115         if !*me.done_first {
116             match ready!(me.first.poll_fill_buf(cx)?) {
117                 buf if buf.is_empty() => {
118                     *me.done_first = true;
119                 }
120                 buf => return Poll::Ready(Ok(buf)),
121             }
122         }
123         me.second.poll_fill_buf(cx)
124     }
125 
consume(self: Pin<&mut Self>, amt: usize)126     fn consume(self: Pin<&mut Self>, amt: usize) {
127         let me = self.project();
128         if !*me.done_first {
129             me.first.consume(amt)
130         } else {
131             me.second.consume(amt)
132         }
133     }
134 }
135 
#[cfg(test)]
mod tests {
    use super::Chain;

    /// `Chain` must be `Unpin` whenever both wrapped readers are.
    #[test]
    fn assert_unpin() {
        crate::is_unpin::<Chain<(), ()>>();
    }
}
145