use crate::io::sys;
use crate::io::{AsyncRead, AsyncWrite, ReadBuf};

use std::cmp;
use std::future::Future;
use std::io;
use std::io::prelude::*;
use std::pin::Pin;
use std::task::Poll::*;
use std::task::{Context, Poll};

use self::State::*;

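/// Bridges a synchronous `Read` or `Write` type to `AsyncRead`/`AsyncWrite` by
/// running its blocking operations on the blocking thread pool.
///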
/// `T` should not implement _both_ Read and Write.
#[derive(Debug)]
pub(crate) struct Blocking<T> {
    inner: Option<T>,
    state: State<T>,
    /// `true` if the lower IO layer needs flushing.
    need_flush: bool,
}

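/// Reusable intermediate buffer with a read cursor, handed back and forth
/// between the async task and the blocking pool.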
#[derive(Debug)]
pub(crate) struct Buf {
    buf: Vec<u8>,
    pos: usize,
}

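/// Maximum number of bytes buffered per blocking read or write (16 KiB).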
pub(crate) const MAX_BUF: usize = 16 * 1024;

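/// Whether the adapter is idle (buffer available for reuse) or waiting on an
/// in-flight blocking operation that yields the result, buffer, and `T` back.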
#[derive(Debug)]
enum State<T> {
    Idle(Option<Buf>),
    Busy(sys::Blocking<(io::Result<usize>, Buf, T)>),
}

cfg_io_std! {
    impl<T> Blocking<T> {
        pub(crate) fn new(inner: T) -> Blocking<T> {
            Blocking {
                inner: Some(inner),
                state: State::Idle(Some(Buf::with_capacity(0))),
                need_flush: false,
            }
        }
    }
}

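// Reading drains any bytes left in the intermediate buffer first; otherwise it
// sizes the buffer for `dst`, performs the read on the blocking pool, and copies
// the result into `dst` once the blocking task completes.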
impl<T> AsyncRead for Blocking<T>
where
    T: Read + Unpin + Send + 'static,
{
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        dst: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        loop {
            match self.state {
                Idle(ref mut buf_cell) => {
                    let mut buf = buf_cell.take().unwrap();

                    if !buf.is_empty() {
                        buf.copy_to(dst);
                        *buf_cell = Some(buf);
                        return Ready(Ok(()));
                    }

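                    // Nothing buffered: size the buffer for `dst` and move the
                    // reader to the blocking pool to perform the read.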
                    buf.ensure_capacity_for(dst);
                    let mut inner = self.inner.take().unwrap();

                    self.state = Busy(sys::run(move || {
                        let res = buf.read_from(&mut inner);
                        (res, buf, inner)
                    }));
                }
                Busy(ref mut rx) => {
                    let (res, mut buf, inner) = ready!(Pin::new(rx).poll(cx))?;
                    self.inner = Some(inner);

                    match res {
                        Ok(_) => {
                            buf.copy_to(dst);
                            self.state = Idle(Some(buf));
                            return Ready(Ok(()));
                        }
                        Err(e) => {
                            assert!(buf.is_empty());

                            self.state = Idle(Some(buf));
                            return Ready(Err(e));
                        }
                    }
                }
            }
        }
    }
}

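// Writes stage the caller's bytes in the intermediate buffer, start the actual
// write on the blocking pool, and report success immediately; a write error only
// surfaces on a later call.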
impl<T> AsyncWrite for Blocking<T>
where
    T: Write + Unpin + Send + 'static,
{
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        src: &[u8],
    ) -> Poll<io::Result<usize>> {
        loop {
            match self.state {
                Idle(ref mut buf_cell) => {
                    let mut buf = buf_cell.take().unwrap();

                    assert!(buf.is_empty());

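                    // Stage up to `MAX_BUF` bytes from `src` and kick off the
                    // blocking write; the staged length is reported right away.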
                    let n = buf.copy_from(src);
                    let mut inner = self.inner.take().unwrap();

                    self.state = Busy(sys::run(move || {
                        let n = buf.len();
                        let res = buf.write_to(&mut inner).map(|_| n);

                        (res, buf, inner)
                    }));
                    self.need_flush = true;

                    return Ready(Ok(n));
                }
                Busy(ref mut rx) => {
                    let (res, buf, inner) = ready!(Pin::new(rx).poll(cx))?;
                    self.state = Idle(Some(buf));
                    self.inner = Some(inner);

                    // If error, return
                    res?;
                }
            }
        }
    }

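    // Flushing waits for any in-flight blocking operation, then dispatches a
    // blocking `flush` only if a write has happened since the last flush.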
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        loop {
            let need_flush = self.need_flush;
            match self.state {
                // The buffer is not used here
                Idle(ref mut buf_cell) => {
                    if need_flush {
                        let buf = buf_cell.take().unwrap();
                        let mut inner = self.inner.take().unwrap();

                        self.state = Busy(sys::run(move || {
                            let res = inner.flush().map(|_| 0);
                            (res, buf, inner)
                        }));

                        self.need_flush = false;
                    } else {
                        return Ready(Ok(()));
                    }
                }
                Busy(ref mut rx) => {
                    let (res, buf, inner) = ready!(Pin::new(rx).poll(cx))?;
                    self.state = Idle(Some(buf));
                    self.inner = Some(inner);

                    // If error, return
                    res?;
                }
            }
        }
    }

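    // Shutting down is a no-op here: no extra work is performed on close.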
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }
}

/// Repeats operations that are interrupted.
macro_rules! uninterruptibly {
    ($e:expr) => {{
        loop {
            match $e {
                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
                res => break res,
            }
        }
    }};
}

impl Buf {
    pub(crate) fn with_capacity(n: usize) -> Buf {
        Buf {
            buf: Vec::with_capacity(n),
            pos: 0,
        }
    }

    pub(crate) fn is_empty(&self) -> bool {
        self.len() == 0
    }

    pub(crate) fn len(&self) -> usize {
        self.buf.len() - self.pos
    }

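    /// Copies buffered bytes into `dst`, advancing the cursor and resetting the
    /// buffer once it has been fully drained. Returns the number of bytes copied.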
    pub(crate) fn copy_to(&mut self, dst: &mut ReadBuf<'_>) -> usize {
        let n = cmp::min(self.len(), dst.remaining());
        dst.put_slice(&self.bytes()[..n]);
        self.pos += n;

        if self.pos == self.buf.len() {
            self.buf.truncate(0);
            self.pos = 0;
        }

        n
    }

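    /// Fills the (empty) buffer from `src`, copying at most `MAX_BUF` bytes.
    /// Returns the number of bytes staged.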
    pub(crate) fn copy_from(&mut self, src: &[u8]) -> usize {
        assert!(self.is_empty());

        let n = cmp::min(src.len(), MAX_BUF);

        self.buf.extend_from_slice(&src[..n]);
        n
    }

    pub(crate) fn bytes(&self) -> &[u8] {
        &self.buf[self.pos..]
    }

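    /// Resizes the empty buffer so a blocking read can fill up to
    /// `bytes.remaining()` bytes, capped at `MAX_BUF`.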
    pub(crate) fn ensure_capacity_for(&mut self, bytes: &ReadBuf<'_>) {
        assert!(self.is_empty());

        let len = cmp::min(bytes.remaining(), MAX_BUF);

        if self.buf.len() < len {
            self.buf.reserve(len - self.buf.len());
        }

        unsafe {
            self.buf.set_len(len);
        }
    }

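    /// Performs a blocking read from `rd` into the buffer, retrying on
    /// `Interrupted` and truncating the buffer to the number of bytes read.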
    pub(crate) fn read_from<T: Read>(&mut self, rd: &mut T) -> io::Result<usize> {
        let res = uninterruptibly!(rd.read(&mut self.buf));

        if let Ok(n) = res {
            self.buf.truncate(n);
        } else {
            self.buf.clear();
        }

        assert_eq!(self.pos, 0);

        res
    }

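    /// Writes the entire buffer to `wr` and clears it regardless of the outcome.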
    pub(crate) fn write_to<T: Write>(&mut self, wr: &mut T) -> io::Result<()> {
        assert_eq!(self.pos, 0);

        // `write_all` already ignores interrupts
        let res = wr.write_all(&self.buf);
        self.buf.clear();
        res
    }
}

cfg_fs! {
    impl Buf {
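        /// Discards any unread buffered bytes, returning the negative count of
        /// bytes that were discarded.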
        pub(crate) fn discard_read(&mut self) -> i64 {
            let ret = -(self.bytes().len() as i64);
            self.pos = 0;
            self.buf.truncate(0);
            ret
        }
    }
}