1 use futures_core::ready;
2 use futures_core::task::{Context, Poll};
3 #[cfg(feature = "read-initializer")]
4 use futures_io::Initializer;
5 use futures_io::{AsyncBufRead, AsyncRead, IoSliceMut};
6 use pin_project_lite::pin_project;
7 use std::fmt;
8 use std::io;
9 use std::pin::Pin;
10 
11 pin_project! {
12     /// Reader for the [`chain`](super::AsyncReadExt::chain) method.
13     #[must_use = "readers do nothing unless polled"]
14     pub struct Chain<T, U> {
15         #[pin]
16         first: T,
17         #[pin]
18         second: U,
19         done_first: bool,
20     }
21 }
22 
23 impl<T, U> Chain<T, U>
24 where
25     T: AsyncRead,
26     U: AsyncRead,
27 {
new(first: T, second: U) -> Self28     pub(super) fn new(first: T, second: U) -> Self {
29         Self {
30             first,
31             second,
32             done_first: false,
33         }
34     }
35 
36     /// Gets references to the underlying readers in this `Chain`.
get_ref(&self) -> (&T, &U)37     pub fn get_ref(&self) -> (&T, &U) {
38         (&self.first, &self.second)
39     }
40 
41     /// Gets mutable references to the underlying readers in this `Chain`.
42     ///
43     /// Care should be taken to avoid modifying the internal I/O state of the
44     /// underlying readers as doing so may corrupt the internal state of this
45     /// `Chain`.
get_mut(&mut self) -> (&mut T, &mut U)46     pub fn get_mut(&mut self) -> (&mut T, &mut U) {
47         (&mut self.first, &mut self.second)
48     }
49 
50     /// Gets pinned mutable references to the underlying readers in this `Chain`.
51     ///
52     /// Care should be taken to avoid modifying the internal I/O state of the
53     /// underlying readers as doing so may corrupt the internal state of this
54     /// `Chain`.
get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>)55     pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>) {
56         let this = self.project();
57         (this.first, this.second)
58     }
59 
60     /// Consumes the `Chain`, returning the wrapped readers.
into_inner(self) -> (T, U)61     pub fn into_inner(self) -> (T, U) {
62         (self.first, self.second)
63     }
64 }
65 
66 impl<T, U> fmt::Debug for Chain<T, U>
67 where
68     T: fmt::Debug,
69     U: fmt::Debug,
70 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result71     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72         f.debug_struct("Chain")
73             .field("t", &self.first)
74             .field("u", &self.second)
75             .field("done_first", &self.done_first)
76             .finish()
77     }
78 }
79 
80 impl<T, U> AsyncRead for Chain<T, U>
81 where
82     T: AsyncRead,
83     U: AsyncRead,
84 {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<io::Result<usize>>85     fn poll_read(
86         self: Pin<&mut Self>,
87         cx: &mut Context<'_>,
88         buf: &mut [u8],
89     ) -> Poll<io::Result<usize>> {
90         let this = self.project();
91 
92         if !*this.done_first {
93             match ready!(this.first.poll_read(cx, buf)?) {
94                 0 if !buf.is_empty() => *this.done_first = true,
95                 n => return Poll::Ready(Ok(n)),
96             }
97         }
98         this.second.poll_read(cx, buf)
99     }
100 
poll_read_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>], ) -> Poll<io::Result<usize>>101     fn poll_read_vectored(
102         self: Pin<&mut Self>,
103         cx: &mut Context<'_>,
104         bufs: &mut [IoSliceMut<'_>],
105     ) -> Poll<io::Result<usize>> {
106         let this = self.project();
107 
108         if !*this.done_first {
109             let n = ready!(this.first.poll_read_vectored(cx, bufs)?);
110             if n == 0 && bufs.iter().any(|b| !b.is_empty()) {
111                 *this.done_first = true
112             } else {
113                 return Poll::Ready(Ok(n));
114             }
115         }
116         this.second.poll_read_vectored(cx, bufs)
117     }
118 
119     #[cfg(feature = "read-initializer")]
initializer(&self) -> Initializer120     unsafe fn initializer(&self) -> Initializer {
121         let initializer = self.first.initializer();
122         if initializer.should_initialize() {
123             initializer
124         } else {
125             self.second.initializer()
126         }
127     }
128 }
129 
130 impl<T, U> AsyncBufRead for Chain<T, U>
131 where
132     T: AsyncBufRead,
133     U: AsyncBufRead,
134 {
poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>135     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
136         let this = self.project();
137 
138         if !*this.done_first {
139             match ready!(this.first.poll_fill_buf(cx)?) {
140                 buf if buf.is_empty() => {
141                     *this.done_first = true;
142                 }
143                 buf => return Poll::Ready(Ok(buf)),
144             }
145         }
146         this.second.poll_fill_buf(cx)
147     }
148 
consume(self: Pin<&mut Self>, amt: usize)149     fn consume(self: Pin<&mut Self>, amt: usize) {
150         let this = self.project();
151 
152         if !*this.done_first {
153             this.first.consume(amt)
154         } else {
155             this.second.consume(amt)
156         }
157     }
158 }
159