// Originally sourced from `futures_util::io::buf_writer`. It is redefined locally so that the
// `AsyncBufWrite` impl can access its internals, and has been changed slightly to make it more
// efficient with those methods.
4 
5 use super::AsyncBufWrite;
6 use futures_core::ready;
7 use pin_project_lite::pin_project;
8 use std::{
9     cmp::min,
10     fmt, io,
11     pin::Pin,
12     task::{Context, Poll},
13 };
14 use tokio_03::io::AsyncWrite;
15 
/// Buffer capacity, in bytes, that `BufWriter::new` passes to `BufWriter::with_capacity`.
const DEFAULT_BUF_SIZE: usize = 8192;
17 
pin_project! {
    /// A writer adapter that stages bytes in a fixed-size in-memory buffer before handing them
    /// to the wrapped writer `W`.
    ///
    /// The bytes in `buf[written..buffered]` are the ones that have been accepted from callers
    /// but not yet written to `inner`.
    pub struct BufWriter<W> {
        // The wrapped writer. Marked `#[pin]` so the projection yields a `Pin<&mut W>` that can
        // be polled.
        #[pin]
        inner: W,
        // Staging buffer; allocated once in `with_capacity` and never resized afterwards.
        buf: Box<[u8]>,
        // Count of bytes at the front of `buf` that `inner` has already accepted but that have
        // not yet been compacted away.
        written: usize,
        // End offset of the valid data in `buf`; new data is appended starting here.
        buffered: usize,
    }
}
27 
28 impl<W: AsyncWrite> BufWriter<W> {
29     /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB,
30     /// but may change in the future.
new(inner: W) -> Self31     pub fn new(inner: W) -> Self {
32         Self::with_capacity(DEFAULT_BUF_SIZE, inner)
33     }
34 
35     /// Creates a new `BufWriter` with the specified buffer capacity.
with_capacity(cap: usize, inner: W) -> Self36     pub fn with_capacity(cap: usize, inner: W) -> Self {
37         Self {
38             inner,
39             buf: vec![0; cap].into(),
40             written: 0,
41             buffered: 0,
42         }
43     }
44 
partial_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>45     fn partial_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
46         let mut this = self.project();
47 
48         let mut ret = Ok(());
49         while *this.written < *this.buffered {
50             match this
51                 .inner
52                 .as_mut()
53                 .poll_write(cx, &this.buf[*this.written..*this.buffered])
54             {
55                 Poll::Pending => {
56                     break;
57                 }
58                 Poll::Ready(Ok(0)) => {
59                     ret = Err(io::Error::new(
60                         io::ErrorKind::WriteZero,
61                         "failed to write the buffered data",
62                     ));
63                     break;
64                 }
65                 Poll::Ready(Ok(n)) => *this.written += n,
66                 Poll::Ready(Err(e)) => {
67                     ret = Err(e);
68                     break;
69                 }
70             }
71         }
72 
73         if *this.written > 0 {
74             this.buf.copy_within(*this.written..*this.buffered, 0);
75             *this.buffered -= *this.written;
76             *this.written = 0;
77 
78             Poll::Ready(ret)
79         } else if *this.buffered == 0 {
80             Poll::Ready(ret)
81         } else {
82             ret?;
83             Poll::Pending
84         }
85     }
86 
flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>87     fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
88         let mut this = self.project();
89 
90         let mut ret = Ok(());
91         while *this.written < *this.buffered {
92             match ready!(this
93                 .inner
94                 .as_mut()
95                 .poll_write(cx, &this.buf[*this.written..*this.buffered]))
96             {
97                 Ok(0) => {
98                     ret = Err(io::Error::new(
99                         io::ErrorKind::WriteZero,
100                         "failed to write the buffered data",
101                     ));
102                     break;
103                 }
104                 Ok(n) => *this.written += n,
105                 Err(e) => {
106                     ret = Err(e);
107                     break;
108                 }
109             }
110         }
111         this.buf.copy_within(*this.written..*this.buffered, 0);
112         *this.buffered -= *this.written;
113         *this.written = 0;
114         Poll::Ready(ret)
115     }
116 
117     /// Gets a reference to the underlying writer.
get_ref(&self) -> &W118     pub fn get_ref(&self) -> &W {
119         &self.inner
120     }
121 
122     /// Gets a mutable reference to the underlying writer.
123     ///
124     /// It is inadvisable to directly write to the underlying writer.
get_mut(&mut self) -> &mut W125     pub fn get_mut(&mut self) -> &mut W {
126         &mut self.inner
127     }
128 
129     /// Gets a pinned mutable reference to the underlying writer.
130     ///
131     /// It is inadvisable to directly write to the underlying writer.
get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W>132     pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
133         self.project().inner
134     }
135 
136     /// Consumes this `BufWriter`, returning the underlying writer.
137     ///
138     /// Note that any leftover data in the internal buffer is lost.
into_inner(self) -> W139     pub fn into_inner(self) -> W {
140         self.inner
141     }
142 }
143 
144 impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>145     fn poll_write(
146         mut self: Pin<&mut Self>,
147         cx: &mut Context<'_>,
148         buf: &[u8],
149     ) -> Poll<io::Result<usize>> {
150         let this = self.as_mut().project();
151         if *this.buffered + buf.len() > this.buf.len() {
152             ready!(self.as_mut().partial_flush_buf(cx))?;
153         }
154 
155         let this = self.as_mut().project();
156         if buf.len() >= this.buf.len() {
157             if *this.buffered == 0 {
158                 this.inner.poll_write(cx, buf)
159             } else {
160                 // The only way that `partial_flush_buf` would have returned with
161                 // `this.buffered != 0` is if it were Pending, so our waker was already queued
162                 Poll::Pending
163             }
164         } else {
165             let len = min(this.buf.len() - *this.buffered, buf.len());
166             this.buf[*this.buffered..*this.buffered + len].copy_from_slice(&buf[..len]);
167             *this.buffered += len;
168             Poll::Ready(Ok(len))
169         }
170     }
171 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>172     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
173         ready!(self.as_mut().flush_buf(cx))?;
174         self.project().inner.poll_flush(cx)
175     }
176 
poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>177     fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
178         ready!(self.as_mut().flush_buf(cx))?;
179         self.project().inner.poll_shutdown(cx)
180     }
181 }
182 
183 impl<W: AsyncWrite> AsyncBufWrite for BufWriter<W> {
poll_partial_flush_buf( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<io::Result<&mut [u8]>>184     fn poll_partial_flush_buf(
185         mut self: Pin<&mut Self>,
186         cx: &mut Context<'_>,
187     ) -> Poll<io::Result<&mut [u8]>> {
188         ready!(self.as_mut().partial_flush_buf(cx))?;
189         let this = self.project();
190         Poll::Ready(Ok(&mut this.buf[*this.buffered..]))
191     }
192 
produce(self: Pin<&mut Self>, amt: usize)193     fn produce(self: Pin<&mut Self>, amt: usize) {
194         *self.project().buffered += amt;
195     }
196 }
197 
198 impl<W: fmt::Debug> fmt::Debug for BufWriter<W> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result199     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200         f.debug_struct("BufWriter")
201             .field("writer", &self.inner)
202             .field(
203                 "buffer",
204                 &format_args!("{}/{}", self.buffered, self.buf.len()),
205             )
206             .field("written", &self.written)
207             .finish()
208     }
209 }
210