1 use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; 2 3 use std::future::Future; 4 use std::io; 5 use std::pin::Pin; 6 use std::task::{Context, Poll}; 7 8 #[derive(Debug)] 9 pub(super) struct CopyBuffer { 10 read_done: bool, 11 need_flush: bool, 12 pos: usize, 13 cap: usize, 14 amt: u64, 15 buf: Box<[u8]>, 16 } 17 18 impl CopyBuffer { new() -> Self19 pub(super) fn new() -> Self { 20 Self { 21 read_done: false, 22 need_flush: false, 23 pos: 0, 24 cap: 0, 25 amt: 0, 26 buf: vec![0; super::DEFAULT_BUF_SIZE].into_boxed_slice(), 27 } 28 } 29 poll_copy<R, W>( &mut self, cx: &mut Context<'_>, mut reader: Pin<&mut R>, mut writer: Pin<&mut W>, ) -> Poll<io::Result<u64>> where R: AsyncRead + ?Sized, W: AsyncWrite + ?Sized,30 pub(super) fn poll_copy<R, W>( 31 &mut self, 32 cx: &mut Context<'_>, 33 mut reader: Pin<&mut R>, 34 mut writer: Pin<&mut W>, 35 ) -> Poll<io::Result<u64>> 36 where 37 R: AsyncRead + ?Sized, 38 W: AsyncWrite + ?Sized, 39 { 40 loop { 41 // If our buffer is empty, then we need to read some data to 42 // continue. 43 if self.pos == self.cap && !self.read_done { 44 let me = &mut *self; 45 let mut buf = ReadBuf::new(&mut me.buf); 46 47 match reader.as_mut().poll_read(cx, &mut buf) { 48 Poll::Ready(Ok(_)) => (), 49 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), 50 Poll::Pending => { 51 // Try flushing when the reader has no progress to avoid deadlock 52 // when the reader depends on buffered writer. 53 if self.need_flush { 54 ready!(writer.as_mut().poll_flush(cx))?; 55 self.need_flush = false; 56 } 57 58 return Poll::Pending; 59 } 60 } 61 62 let n = buf.filled().len(); 63 if n == 0 { 64 self.read_done = true; 65 } else { 66 self.pos = 0; 67 self.cap = n; 68 } 69 } 70 71 // If our buffer has some data, let's write it out! 72 while self.pos < self.cap { 73 let me = &mut *self; 74 let i = ready!(writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]))?; 75 if i == 0 { 76 return Poll::Ready(Err(io::Error::new( 77 io::ErrorKind::WriteZero, 78 "write zero byte into writer", 79 ))); 80 } else { 81 self.pos += i; 82 self.amt += i as u64; 83 self.need_flush = true; 84 } 85 } 86 87 // If pos larger than cap, this loop will never stop. 88 // In particular, user's wrong poll_write implementation returning 89 // incorrect written length may lead to thread blocking. 90 debug_assert!( 91 self.pos <= self.cap, 92 "writer returned length larger than input slice" 93 ); 94 95 // If we've written all the data and we've seen EOF, flush out the 96 // data and finish the transfer. 97 if self.pos == self.cap && self.read_done { 98 ready!(writer.as_mut().poll_flush(cx))?; 99 return Poll::Ready(Ok(self.amt)); 100 } 101 } 102 } 103 } 104 105 /// A future that asynchronously copies the entire contents of a reader into a 106 /// writer. 107 #[derive(Debug)] 108 #[must_use = "futures do nothing unless you `.await` or poll them"] 109 struct Copy<'a, R: ?Sized, W: ?Sized> { 110 reader: &'a mut R, 111 writer: &'a mut W, 112 buf: CopyBuffer, 113 } 114 115 cfg_io_util! { 116 /// Asynchronously copies the entire contents of a reader into a writer. 117 /// 118 /// This function returns a future that will continuously read data from 119 /// `reader` and then write it into `writer` in a streaming fashion until 120 /// `reader` returns EOF. 121 /// 122 /// On success, the total number of bytes that were copied from `reader` to 123 /// `writer` is returned. 124 /// 125 /// This is an asynchronous version of [`std::io::copy`][std]. 126 /// 127 /// [std]: std::io::copy 128 /// 129 /// # Errors 130 /// 131 /// The returned future will return an error immediately if any call to 132 /// `poll_read` or `poll_write` returns an error. 133 /// 134 /// # Examples 135 /// 136 /// ``` 137 /// use tokio::io; 138 /// 139 /// # async fn dox() -> std::io::Result<()> { 140 /// let mut reader: &[u8] = b"hello"; 141 /// let mut writer: Vec<u8> = vec![]; 142 /// 143 /// io::copy(&mut reader, &mut writer).await?; 144 /// 145 /// assert_eq!(&b"hello"[..], &writer[..]); 146 /// # Ok(()) 147 /// # } 148 /// ``` 149 pub async fn copy<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result<u64> 150 where 151 R: AsyncRead + Unpin + ?Sized, 152 W: AsyncWrite + Unpin + ?Sized, 153 { 154 Copy { 155 reader, 156 writer, 157 buf: CopyBuffer::new() 158 }.await 159 } 160 } 161 162 impl<R, W> Future for Copy<'_, R, W> 163 where 164 R: AsyncRead + Unpin + ?Sized, 165 W: AsyncWrite + Unpin + ?Sized, 166 { 167 type Output = io::Result<u64>; 168 poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>>169 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { 170 let me = &mut *self; 171 172 me.buf 173 .poll_copy(cx, Pin::new(&mut *me.reader), Pin::new(&mut *me.writer)) 174 } 175 } 176