1 use futures_core::ready;
2 use futures_core::task::{Context, Poll};
3 use futures_io::{AsyncBufRead, AsyncRead, IoSliceMut};
4 use pin_project_lite::pin_project;
5 use std::fmt;
6 use std::io;
7 use std::pin::Pin;
8 
pin_project! {
    /// Reader for the [`chain`](super::AsyncReadExt::chain) method.
    ///
    /// Reads from `first` until it reports end of stream, then forwards every
    /// subsequent read (and buffered read) to `second`.
    #[must_use = "readers do nothing unless polled"]
    pub struct Chain<T, U> {
        // Reader that is drained first; structurally pinned.
        #[pin]
        first: T,
        // Reader consulted once `first` has reported end of stream; structurally pinned.
        #[pin]
        second: U,
        // Set to `true` once `first` signals end of stream (a 0-byte read into a
        // non-empty buffer, or an empty `poll_fill_buf` result). Nothing in this
        // type clears it again, so `first` is never polled after that point.
        done_first: bool,
    }
}
20 
21 impl<T, U> Chain<T, U>
22 where
23     T: AsyncRead,
24     U: AsyncRead,
25 {
new(first: T, second: U) -> Self26     pub(super) fn new(first: T, second: U) -> Self {
27         Self { first, second, done_first: false }
28     }
29 
30     /// Gets references to the underlying readers in this `Chain`.
get_ref(&self) -> (&T, &U)31     pub fn get_ref(&self) -> (&T, &U) {
32         (&self.first, &self.second)
33     }
34 
35     /// Gets mutable references to the underlying readers in this `Chain`.
36     ///
37     /// Care should be taken to avoid modifying the internal I/O state of the
38     /// underlying readers as doing so may corrupt the internal state of this
39     /// `Chain`.
get_mut(&mut self) -> (&mut T, &mut U)40     pub fn get_mut(&mut self) -> (&mut T, &mut U) {
41         (&mut self.first, &mut self.second)
42     }
43 
44     /// Gets pinned mutable references to the underlying readers in this `Chain`.
45     ///
46     /// Care should be taken to avoid modifying the internal I/O state of the
47     /// underlying readers as doing so may corrupt the internal state of this
48     /// `Chain`.
get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>)49     pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>) {
50         let this = self.project();
51         (this.first, this.second)
52     }
53 
54     /// Consumes the `Chain`, returning the wrapped readers.
into_inner(self) -> (T, U)55     pub fn into_inner(self) -> (T, U) {
56         (self.first, self.second)
57     }
58 }
59 
60 impl<T, U> fmt::Debug for Chain<T, U>
61 where
62     T: fmt::Debug,
63     U: fmt::Debug,
64 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result65     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66         f.debug_struct("Chain")
67             .field("t", &self.first)
68             .field("u", &self.second)
69             .field("done_first", &self.done_first)
70             .finish()
71     }
72 }
73 
74 impl<T, U> AsyncRead for Chain<T, U>
75 where
76     T: AsyncRead,
77     U: AsyncRead,
78 {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<io::Result<usize>>79     fn poll_read(
80         self: Pin<&mut Self>,
81         cx: &mut Context<'_>,
82         buf: &mut [u8],
83     ) -> Poll<io::Result<usize>> {
84         let this = self.project();
85 
86         if !*this.done_first {
87             match ready!(this.first.poll_read(cx, buf)?) {
88                 0 if !buf.is_empty() => *this.done_first = true,
89                 n => return Poll::Ready(Ok(n)),
90             }
91         }
92         this.second.poll_read(cx, buf)
93     }
94 
poll_read_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>], ) -> Poll<io::Result<usize>>95     fn poll_read_vectored(
96         self: Pin<&mut Self>,
97         cx: &mut Context<'_>,
98         bufs: &mut [IoSliceMut<'_>],
99     ) -> Poll<io::Result<usize>> {
100         let this = self.project();
101 
102         if !*this.done_first {
103             let n = ready!(this.first.poll_read_vectored(cx, bufs)?);
104             if n == 0 && bufs.iter().any(|b| !b.is_empty()) {
105                 *this.done_first = true
106             } else {
107                 return Poll::Ready(Ok(n));
108             }
109         }
110         this.second.poll_read_vectored(cx, bufs)
111     }
112 }
113 
114 impl<T, U> AsyncBufRead for Chain<T, U>
115 where
116     T: AsyncBufRead,
117     U: AsyncBufRead,
118 {
poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>119     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
120         let this = self.project();
121 
122         if !*this.done_first {
123             match ready!(this.first.poll_fill_buf(cx)?) {
124                 buf if buf.is_empty() => {
125                     *this.done_first = true;
126                 }
127                 buf => return Poll::Ready(Ok(buf)),
128             }
129         }
130         this.second.poll_fill_buf(cx)
131     }
132 
consume(self: Pin<&mut Self>, amt: usize)133     fn consume(self: Pin<&mut Self>, amt: usize) {
134         let this = self.project();
135 
136         if !*this.done_first {
137             this.first.consume(amt)
138         } else {
139             this.second.consume(amt)
140         }
141     }
142 }
143