1 // Originally sourced from `futures_util::io::buf_writer`, needs to be redefined locally so that 2 // the `AsyncBufWrite` impl can access its internals, and changed a bit to make it more efficient 3 // with those methods. 4 5 use super::AsyncBufWrite; 6 use futures_core::ready; 7 use pin_project_lite::pin_project; 8 use std::{ 9 cmp::min, 10 fmt, io, 11 pin::Pin, 12 task::{Context, Poll}, 13 }; 14 use tokio_03::io::AsyncWrite; 15 16 const DEFAULT_BUF_SIZE: usize = 8192; 17 18 pin_project! { 19 pub struct BufWriter<W> { 20 #[pin] 21 inner: W, 22 buf: Box<[u8]>, 23 written: usize, 24 buffered: usize, 25 } 26 } 27 28 impl<W: AsyncWrite> BufWriter<W> { 29 /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB, 30 /// but may change in the future. new(inner: W) -> Self31 pub fn new(inner: W) -> Self { 32 Self::with_capacity(DEFAULT_BUF_SIZE, inner) 33 } 34 35 /// Creates a new `BufWriter` with the specified buffer capacity. with_capacity(cap: usize, inner: W) -> Self36 pub fn with_capacity(cap: usize, inner: W) -> Self { 37 Self { 38 inner, 39 buf: vec![0; cap].into(), 40 written: 0, 41 buffered: 0, 42 } 43 } 44 partial_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>45 fn partial_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 46 let mut this = self.project(); 47 48 let mut ret = Ok(()); 49 while *this.written < *this.buffered { 50 match this 51 .inner 52 .as_mut() 53 .poll_write(cx, &this.buf[*this.written..*this.buffered]) 54 { 55 Poll::Pending => { 56 break; 57 } 58 Poll::Ready(Ok(0)) => { 59 ret = Err(io::Error::new( 60 io::ErrorKind::WriteZero, 61 "failed to write the buffered data", 62 )); 63 break; 64 } 65 Poll::Ready(Ok(n)) => *this.written += n, 66 Poll::Ready(Err(e)) => { 67 ret = Err(e); 68 break; 69 } 70 } 71 } 72 73 if *this.written > 0 { 74 this.buf.copy_within(*this.written..*this.buffered, 0); 75 *this.buffered -= *this.written; 76 *this.written = 0; 77 78 Poll::Ready(ret) 79 } else if *this.buffered == 0 { 80 Poll::Ready(ret) 81 } else { 82 ret?; 83 Poll::Pending 84 } 85 } 86 flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>87 fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 88 let mut this = self.project(); 89 90 let mut ret = Ok(()); 91 while *this.written < *this.buffered { 92 match ready!(this 93 .inner 94 .as_mut() 95 .poll_write(cx, &this.buf[*this.written..*this.buffered])) 96 { 97 Ok(0) => { 98 ret = Err(io::Error::new( 99 io::ErrorKind::WriteZero, 100 "failed to write the buffered data", 101 )); 102 break; 103 } 104 Ok(n) => *this.written += n, 105 Err(e) => { 106 ret = Err(e); 107 break; 108 } 109 } 110 } 111 this.buf.copy_within(*this.written..*this.buffered, 0); 112 *this.buffered -= *this.written; 113 *this.written = 0; 114 Poll::Ready(ret) 115 } 116 117 /// Gets a reference to the underlying writer. get_ref(&self) -> &W118 pub fn get_ref(&self) -> &W { 119 &self.inner 120 } 121 122 /// Gets a mutable reference to the underlying writer. 123 /// 124 /// It is inadvisable to directly write to the underlying writer. get_mut(&mut self) -> &mut W125 pub fn get_mut(&mut self) -> &mut W { 126 &mut self.inner 127 } 128 129 /// Gets a pinned mutable reference to the underlying writer. 130 /// 131 /// It is inadvisable to directly write to the underlying writer. get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W>132 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> { 133 self.project().inner 134 } 135 136 /// Consumes this `BufWriter`, returning the underlying writer. 137 /// 138 /// Note that any leftover data in the internal buffer is lost. into_inner(self) -> W139 pub fn into_inner(self) -> W { 140 self.inner 141 } 142 } 143 144 impl<W: AsyncWrite> AsyncWrite for BufWriter<W> { poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>145 fn poll_write( 146 mut self: Pin<&mut Self>, 147 cx: &mut Context<'_>, 148 buf: &[u8], 149 ) -> Poll<io::Result<usize>> { 150 let this = self.as_mut().project(); 151 if *this.buffered + buf.len() > this.buf.len() { 152 ready!(self.as_mut().partial_flush_buf(cx))?; 153 } 154 155 let this = self.as_mut().project(); 156 if buf.len() >= this.buf.len() { 157 if *this.buffered == 0 { 158 this.inner.poll_write(cx, buf) 159 } else { 160 // The only way that `partial_flush_buf` would have returned with 161 // `this.buffered != 0` is if it were Pending, so our waker was already queued 162 Poll::Pending 163 } 164 } else { 165 let len = min(this.buf.len() - *this.buffered, buf.len()); 166 this.buf[*this.buffered..*this.buffered + len].copy_from_slice(&buf[..len]); 167 *this.buffered += len; 168 Poll::Ready(Ok(len)) 169 } 170 } 171 poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>172 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 173 ready!(self.as_mut().flush_buf(cx))?; 174 self.project().inner.poll_flush(cx) 175 } 176 poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>177 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 178 ready!(self.as_mut().flush_buf(cx))?; 179 self.project().inner.poll_shutdown(cx) 180 } 181 } 182 183 impl<W: AsyncWrite> AsyncBufWrite for BufWriter<W> { poll_partial_flush_buf( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<io::Result<&mut [u8]>>184 fn poll_partial_flush_buf( 185 mut self: Pin<&mut Self>, 186 cx: &mut Context<'_>, 187 ) -> Poll<io::Result<&mut [u8]>> { 188 ready!(self.as_mut().partial_flush_buf(cx))?; 189 let this = self.project(); 190 Poll::Ready(Ok(&mut this.buf[*this.buffered..])) 191 } 192 produce(self: Pin<&mut Self>, amt: usize)193 fn produce(self: Pin<&mut Self>, amt: usize) { 194 *self.project().buffered += amt; 195 } 196 } 197 198 impl<W: fmt::Debug> fmt::Debug for BufWriter<W> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result199 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 200 f.debug_struct("BufWriter") 201 .field("writer", &self.inner) 202 .field( 203 "buffer", 204 &format_args!("{}/{}", self.buffered, self.buf.len()), 205 ) 206 .field("written", &self.written) 207 .finish() 208 } 209 } 210