1 use crate::io::{AsyncBufRead, AsyncRead, ReadBuf};
2
3 use pin_project_lite::pin_project;
4 use std::fmt;
5 use std::io;
6 use std::pin::Pin;
7 use std::task::{Context, Poll};
8
pin_project! {
    /// Stream for the [`chain`](super::AsyncReadExt::chain) method.
    ///
    /// Reads from `first` until it reports end-of-file, then reads from `second`.
    #[must_use = "streams do nothing unless polled"]
    #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
    pub struct Chain<T, U> {
        // Reader consulted first; structurally pinned so it can be polled in place.
        #[pin]
        first: T,
        // Reader consulted after `first` is exhausted; also structurally pinned.
        #[pin]
        second: U,
        // Set once `first` has signalled EOF; from then on all I/O goes to `second`.
        done_first: bool,
    }
}
21
chain<T, U>(first: T, second: U) -> Chain<T, U> where T: AsyncRead, U: AsyncRead,22 pub(super) fn chain<T, U>(first: T, second: U) -> Chain<T, U>
23 where
24 T: AsyncRead,
25 U: AsyncRead,
26 {
27 Chain {
28 first,
29 second,
30 done_first: false,
31 }
32 }
33
34 impl<T, U> Chain<T, U>
35 where
36 T: AsyncRead,
37 U: AsyncRead,
38 {
39 /// Gets references to the underlying readers in this `Chain`.
get_ref(&self) -> (&T, &U)40 pub fn get_ref(&self) -> (&T, &U) {
41 (&self.first, &self.second)
42 }
43
44 /// Gets mutable references to the underlying readers in this `Chain`.
45 ///
46 /// Care should be taken to avoid modifying the internal I/O state of the
47 /// underlying readers as doing so may corrupt the internal state of this
48 /// `Chain`.
get_mut(&mut self) -> (&mut T, &mut U)49 pub fn get_mut(&mut self) -> (&mut T, &mut U) {
50 (&mut self.first, &mut self.second)
51 }
52
53 /// Gets pinned mutable references to the underlying readers in this `Chain`.
54 ///
55 /// Care should be taken to avoid modifying the internal I/O state of the
56 /// underlying readers as doing so may corrupt the internal state of this
57 /// `Chain`.
get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>)58 pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>) {
59 let me = self.project();
60 (me.first, me.second)
61 }
62
63 /// Consumes the `Chain`, returning the wrapped readers.
into_inner(self) -> (T, U)64 pub fn into_inner(self) -> (T, U) {
65 (self.first, self.second)
66 }
67 }
68
69 impl<T, U> fmt::Debug for Chain<T, U>
70 where
71 T: fmt::Debug,
72 U: fmt::Debug,
73 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result74 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 f.debug_struct("Chain")
76 .field("t", &self.first)
77 .field("u", &self.second)
78 .finish()
79 }
80 }
81
82 impl<T, U> AsyncRead for Chain<T, U>
83 where
84 T: AsyncRead,
85 U: AsyncRead,
86 {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>87 fn poll_read(
88 self: Pin<&mut Self>,
89 cx: &mut Context<'_>,
90 buf: &mut ReadBuf<'_>,
91 ) -> Poll<io::Result<()>> {
92 let me = self.project();
93
94 if !*me.done_first {
95 let rem = buf.remaining();
96 ready!(me.first.poll_read(cx, buf))?;
97 if buf.remaining() == rem {
98 *me.done_first = true;
99 } else {
100 return Poll::Ready(Ok(()));
101 }
102 }
103 me.second.poll_read(cx, buf)
104 }
105 }
106
107 impl<T, U> AsyncBufRead for Chain<T, U>
108 where
109 T: AsyncBufRead,
110 U: AsyncBufRead,
111 {
poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>112 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
113 let me = self.project();
114
115 if !*me.done_first {
116 match ready!(me.first.poll_fill_buf(cx)?) {
117 buf if buf.is_empty() => {
118 *me.done_first = true;
119 }
120 buf => return Poll::Ready(Ok(buf)),
121 }
122 }
123 me.second.poll_fill_buf(cx)
124 }
125
consume(self: Pin<&mut Self>, amt: usize)126 fn consume(self: Pin<&mut Self>, amt: usize) {
127 let me = self.project();
128 if !*me.done_first {
129 me.first.consume(amt)
130 } else {
131 me.second.consume(amt)
132 }
133 }
134 }
135
#[cfg(test)]
mod tests {
    use super::Chain;

    /// Compile-time check that `Chain` is `Unpin` when built from `Unpin`
    /// readers (here the unit type).
    #[test]
    fn assert_unpin() {
        crate::is_unpin::<Chain<(), ()>>();
    }
}
145