1 use futures_core::task::{Context, Poll};
2 #[cfg(feature = "read-initializer")]
3 use futures_io::Initializer;
4 use futures_io::{AsyncBufRead, AsyncRead, IoSliceMut};
5 use pin_project::{pin_project, project};
6 use std::fmt;
7 use std::io;
8 use std::pin::Pin;
9 
/// Reader for the [`chain`](super::AsyncReadExt::chain) method.
///
/// Reads from `first` until it reports EOF, then reads from `second`.
#[pin_project]
#[must_use = "readers do nothing unless polled"]
pub struct Chain<T, U> {
    /// Reader that is drained first; structurally pinned via `#[pin]`.
    #[pin]
    first: T,
    /// Reader used once `first` has reported EOF; structurally pinned via `#[pin]`.
    #[pin]
    second: U,
    /// Set once `first` has reported EOF; from then on reads, `poll_fill_buf`
    /// and `consume` are routed to `second`.
    done_first: bool,
}
20 
21 impl<T, U> Chain<T, U>
22 where
23     T: AsyncRead,
24     U: AsyncRead,
25 {
new(first: T, second: U) -> Self26     pub(super) fn new(first: T, second: U) -> Self {
27         Self {
28             first,
29             second,
30             done_first: false,
31         }
32     }
33 
34     /// Gets references to the underlying readers in this `Chain`.
get_ref(&self) -> (&T, &U)35     pub fn get_ref(&self) -> (&T, &U) {
36         (&self.first, &self.second)
37     }
38 
39     /// Gets mutable references to the underlying readers in this `Chain`.
40     ///
41     /// Care should be taken to avoid modifying the internal I/O state of the
42     /// underlying readers as doing so may corrupt the internal state of this
43     /// `Chain`.
get_mut(&mut self) -> (&mut T, &mut U)44     pub fn get_mut(&mut self) -> (&mut T, &mut U) {
45         (&mut self.first, &mut self.second)
46     }
47 
48     /// Gets pinned mutable references to the underlying readers in this `Chain`.
49     ///
50     /// Care should be taken to avoid modifying the internal I/O state of the
51     /// underlying readers as doing so may corrupt the internal state of this
52     /// `Chain`.
get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>)53     pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>) {
54         unsafe {
55             let Self { first, second, .. } = self.get_unchecked_mut();
56             (Pin::new_unchecked(first), Pin::new_unchecked(second))
57         }
58     }
59 
60     /// Consumes the `Chain`, returning the wrapped readers.
into_inner(self) -> (T, U)61     pub fn into_inner(self) -> (T, U) {
62         (self.first, self.second)
63     }
64 }
65 
66 impl<T, U> fmt::Debug for Chain<T, U>
67 where
68     T: fmt::Debug,
69     U: fmt::Debug,
70 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result71     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72         f.debug_struct("Chain")
73             .field("t", &self.first)
74             .field("u", &self.second)
75             .field("done_first", &self.done_first)
76             .finish()
77     }
78 }
79 
80 impl<T, U> AsyncRead for Chain<T, U>
81 where
82     T: AsyncRead,
83     U: AsyncRead,
84 {
85     #[project]
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<io::Result<usize>>86     fn poll_read(
87         self: Pin<&mut Self>,
88         cx: &mut Context<'_>,
89         buf: &mut [u8],
90     ) -> Poll<io::Result<usize>> {
91         #[project]
92         let Chain { first, second, done_first } = self.project();
93 
94         if !*done_first {
95             match ready!(first.poll_read(cx, buf)?) {
96                 0 if !buf.is_empty() => *done_first = true,
97                 n => return Poll::Ready(Ok(n)),
98             }
99         }
100         second.poll_read(cx, buf)
101     }
102 
103     #[project]
poll_read_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>], ) -> Poll<io::Result<usize>>104     fn poll_read_vectored(
105         self: Pin<&mut Self>,
106         cx: &mut Context<'_>,
107         bufs: &mut [IoSliceMut<'_>],
108     ) -> Poll<io::Result<usize>> {
109         #[project]
110         let Chain { first, second, done_first } = self.project();
111 
112         if !*done_first {
113             let n = ready!(first.poll_read_vectored(cx, bufs)?);
114             if n == 0 && bufs.iter().any(|b| !b.is_empty()) {
115                 *done_first = true
116             } else {
117                 return Poll::Ready(Ok(n));
118             }
119         }
120         second.poll_read_vectored(cx, bufs)
121     }
122 
123     #[cfg(feature = "read-initializer")]
initializer(&self) -> Initializer124     unsafe fn initializer(&self) -> Initializer {
125         let initializer = self.first.initializer();
126         if initializer.should_initialize() {
127             initializer
128         } else {
129             self.second.initializer()
130         }
131     }
132 }
133 
134 impl<T, U> AsyncBufRead for Chain<T, U>
135 where
136     T: AsyncBufRead,
137     U: AsyncBufRead,
138 {
139     #[project]
poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>140     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
141         #[project]
142         let Chain { first, second, done_first } = self.project();
143 
144         if !*done_first {
145             match ready!(first.poll_fill_buf(cx)?) {
146                 buf if buf.is_empty() => {
147                     *done_first = true;
148                 }
149                 buf => return Poll::Ready(Ok(buf)),
150             }
151         }
152         second.poll_fill_buf(cx)
153     }
154 
155     #[project]
consume(self: Pin<&mut Self>, amt: usize)156     fn consume(self: Pin<&mut Self>, amt: usize) {
157         #[project]
158         let Chain { first, second, done_first } = self.project();
159 
160         if !*done_first {
161             first.consume(amt)
162         } else {
163             second.consume(amt)
164         }
165     }
166 }
167