#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

// https://github.com/rust-lang/futures-rs/blob/1803948ff091b4eabf7f3bf39e16bbbdefca5cc8/futures/tests/io_buf_writer.rs

use futures::task::{Context, Poll};
use std::io::{self, Cursor};
use std::pin::Pin;
use tokio::io::{AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufWriter, SeekFrom};

use futures::future;
use tokio_test::assert_ok;

use std::cmp;
use std::io::IoSlice;

mod support {
    pub(crate) mod io_vec;
}
use support::io_vec::IoBufs;

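// A writer that returns `Poll::Pending` on every other `poll_write` call
// (waking the task so it is polled again), to exercise `BufWriter` against
// an inner writer that is not always immediately ready.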
struct MaybePending {
    inner: Vec<u8>,
    ready: bool,
}

impl MaybePending {
    fn new(inner: Vec<u8>) -> Self {
        Self {
            inner,
            ready: false,
        }
    }
}

impl AsyncWrite for MaybePending {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        if self.ready {
            self.ready = false;
            Pin::new(&mut self.inner).poll_write(cx, buf)
        } else {
            self.ready = true;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }
}

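// Test helper: drive a single `poll_write_vectored` call on `writer` to
// completion and return how many bytes it reported as written.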
async fn write_vectored<W>(writer: &mut W, bufs: &[IoSlice<'_>]) -> io::Result<usize>
where
    W: AsyncWrite + Unpin,
{
    let mut writer = Pin::new(writer);
    future::poll_fn(|cx| writer.as_mut().poll_write_vectored(cx, bufs)).await
}

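// With a 2-byte capacity, writes smaller than the capacity are buffered,
// while writes at least as large as the capacity flush the buffer and go
// straight to the inner `Vec`; the assertions track both the buffer contents
// and what has reached the inner writer.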
#[tokio::test]
async fn buf_writer() {
    let mut writer = BufWriter::with_capacity(2, Vec::new());

    writer.write(&[0, 1]).await.unwrap();
    assert_eq!(writer.buffer(), []);
    assert_eq!(*writer.get_ref(), [0, 1]);

    writer.write(&[2]).await.unwrap();
    assert_eq!(writer.buffer(), [2]);
    assert_eq!(*writer.get_ref(), [0, 1]);

    writer.write(&[3]).await.unwrap();
    assert_eq!(writer.buffer(), [2, 3]);
    assert_eq!(*writer.get_ref(), [0, 1]);

    writer.flush().await.unwrap();
    assert_eq!(writer.buffer(), []);
    assert_eq!(*writer.get_ref(), [0, 1, 2, 3]);

    writer.write(&[4]).await.unwrap();
    writer.write(&[5]).await.unwrap();
    assert_eq!(writer.buffer(), [4, 5]);
    assert_eq!(*writer.get_ref(), [0, 1, 2, 3]);

    writer.write(&[6]).await.unwrap();
    assert_eq!(writer.buffer(), [6]);
    assert_eq!(*writer.get_ref(), [0, 1, 2, 3, 4, 5]);

    writer.write(&[7, 8]).await.unwrap();
    assert_eq!(writer.buffer(), []);
    assert_eq!(*writer.get_ref(), [0, 1, 2, 3, 4, 5, 6, 7, 8]);

    writer.write(&[9, 10, 11]).await.unwrap();
    assert_eq!(writer.buffer(), []);
    assert_eq!(*writer.get_ref(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);

    writer.flush().await.unwrap();
    assert_eq!(writer.buffer(), []);
    assert_eq!(*writer.get_ref(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]);
}

#[tokio::test]
async fn buf_writer_inner_flushes() {
    let mut w = BufWriter::with_capacity(3, Vec::new());
    w.write(&[0, 1]).await.unwrap();
    assert_eq!(*w.get_ref(), []);
    w.flush().await.unwrap();
    let w = w.into_inner();
    assert_eq!(w, [0, 1]);
}

#[tokio::test]
async fn buf_writer_seek() {
    let mut w = BufWriter::with_capacity(3, Cursor::new(Vec::new()));
    w.write_all(&[0, 1, 2, 3, 4, 5]).await.unwrap();
    w.write_all(&[6, 7]).await.unwrap();
    assert_eq!(w.seek(SeekFrom::Current(0)).await.unwrap(), 8);
    assert_eq!(&w.get_ref().get_ref()[..], &[0, 1, 2, 3, 4, 5, 6, 7][..]);
    assert_eq!(w.seek(SeekFrom::Start(2)).await.unwrap(), 2);
    w.write_all(&[8, 9]).await.unwrap();
    w.flush().await.unwrap();
    assert_eq!(&w.into_inner().into_inner()[..], &[0, 1, 8, 9, 4, 5, 6, 7]);
}

#[tokio::test]
async fn maybe_pending_buf_writer() {
    let mut writer = BufWriter::with_capacity(2, MaybePending::new(Vec::new()));

    writer.write(&[0, 1]).await.unwrap();
    assert_eq!(writer.buffer(), []);
    assert_eq!(&writer.get_ref().inner, &[0, 1]);

    writer.write(&[2]).await.unwrap();
    assert_eq!(writer.buffer(), [2]);
    assert_eq!(&writer.get_ref().inner, &[0, 1]);

    writer.write(&[3]).await.unwrap();
    assert_eq!(writer.buffer(), [2, 3]);
    assert_eq!(&writer.get_ref().inner, &[0, 1]);

    writer.flush().await.unwrap();
    assert_eq!(writer.buffer(), []);
    assert_eq!(&writer.get_ref().inner, &[0, 1, 2, 3]);

    writer.write(&[4]).await.unwrap();
    writer.write(&[5]).await.unwrap();
    assert_eq!(writer.buffer(), [4, 5]);
    assert_eq!(&writer.get_ref().inner, &[0, 1, 2, 3]);

    writer.write(&[6]).await.unwrap();
    assert_eq!(writer.buffer(), [6]);
    assert_eq!(writer.get_ref().inner, &[0, 1, 2, 3, 4, 5]);

    writer.write(&[7, 8]).await.unwrap();
    assert_eq!(writer.buffer(), []);
    assert_eq!(writer.get_ref().inner, &[0, 1, 2, 3, 4, 5, 6, 7, 8]);

    writer.write(&[9, 10, 11]).await.unwrap();
    assert_eq!(writer.buffer(), []);
    assert_eq!(
        writer.get_ref().inner,
        &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
    );

    writer.flush().await.unwrap();
    assert_eq!(writer.buffer(), []);
    assert_eq!(
        &writer.get_ref().inner,
        &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
    );
}

#[tokio::test]
async fn maybe_pending_buf_writer_inner_flushes() {
    let mut w = BufWriter::with_capacity(3, MaybePending::new(Vec::new()));
    w.write(&[0, 1]).await.unwrap();
    assert_eq!(&w.get_ref().inner, &[]);
    w.flush().await.unwrap();
    let w = w.into_inner().inner;
    assert_eq!(w, [0, 1]);
}

#[tokio::test]
async fn maybe_pending_buf_writer_seek() {
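    // Like `MaybePending`, but also defers seeks: `poll_complete` returns
    // `Poll::Pending` on its first call so the seek path is exercised against
    // an inner writer that is not immediately ready.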
    struct MaybePendingSeek {
        inner: Cursor<Vec<u8>>,
        ready_write: bool,
        ready_seek: bool,
        seek_res: Option<io::Result<()>>,
    }

    impl MaybePendingSeek {
        fn new(inner: Vec<u8>) -> Self {
            Self {
                inner: Cursor::new(inner),
                ready_write: false,
                ready_seek: false,
                seek_res: None,
            }
        }
    }

    impl AsyncWrite for MaybePendingSeek {
        fn poll_write(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            if self.ready_write {
                self.ready_write = false;
                Pin::new(&mut self.inner).poll_write(cx, buf)
            } else {
                self.ready_write = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }

        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.inner).poll_flush(cx)
        }

        fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Pin::new(&mut self.inner).poll_shutdown(cx)
        }
    }

    impl AsyncSeek for MaybePendingSeek {
        fn start_seek(mut self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> {
            self.seek_res = Some(Pin::new(&mut self.inner).start_seek(pos));
            Ok(())
        }
        fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
            if self.ready_seek {
                self.ready_seek = false;
                self.seek_res.take().unwrap_or(Ok(()))?;
                Pin::new(&mut self.inner).poll_complete(cx)
            } else {
                self.ready_seek = true;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }

    let mut w = BufWriter::with_capacity(3, MaybePendingSeek::new(Vec::new()));
    w.write_all(&[0, 1, 2, 3, 4, 5]).await.unwrap();
    w.write_all(&[6, 7]).await.unwrap();
    assert_eq!(w.seek(SeekFrom::Current(0)).await.unwrap(), 8);
    assert_eq!(
        &w.get_ref().inner.get_ref()[..],
        &[0, 1, 2, 3, 4, 5, 6, 7][..]
    );
    assert_eq!(w.seek(SeekFrom::Start(2)).await.unwrap(), 2);
    w.write_all(&[8, 9]).await.unwrap();
    w.flush().await.unwrap();
    assert_eq!(
        &w.into_inner().inner.into_inner()[..],
        &[0, 1, 8, 9, 4, 5, 6, 7]
    );
}

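// A writer that records everything it receives but accepts at most
// `write_len` bytes per write call; `vectored` controls what
// `is_write_vectored` reports, so both `BufWriter` forwarding paths can be
// tested.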
struct MockWriter {
    data: Vec<u8>,
    write_len: usize,
    vectored: bool,
}

impl MockWriter {
    fn new(write_len: usize) -> Self {
        MockWriter {
            data: Vec::new(),
            write_len,
            vectored: false,
        }
    }

    fn vectored(write_len: usize) -> Self {
        MockWriter {
            data: Vec::new(),
            write_len,
            vectored: true,
        }
    }

    fn write_up_to(&mut self, buf: &[u8], limit: usize) -> usize {
        let len = cmp::min(buf.len(), limit);
        self.data.extend_from_slice(&buf[..len]);
        len
    }
}

impl AsyncWrite for MockWriter {
    fn poll_write(
        self: Pin<&mut Self>,
        _: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let this = self.get_mut();
        let n = this.write_up_to(buf, this.write_len);
        Ok(n).into()
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        _: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        let this = self.get_mut();
        let mut total_written = 0;
        for buf in bufs {
            let n = this.write_up_to(buf, this.write_len - total_written);
            total_written += n;
            if total_written == this.write_len {
                break;
            }
        }
        Ok(total_written).into()
    }

    fn is_write_vectored(&self) -> bool {
        self.vectored
    }

    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Ok(()).into()
    }

    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Ok(()).into()
    }
}

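// Each `*_on_non_vectored` / `*_on_vectored` pair below runs the same
// scenario against a `MockWriter` that does not / does advertise vectored
// write support, since `BufWriter` forwards writes differently depending on
// `is_write_vectored`.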
#[tokio::test]
async fn write_vectored_empty_on_non_vectored() {
    let mut w = BufWriter::new(MockWriter::new(4));
    let n = assert_ok!(write_vectored(&mut w, &[]).await);
    assert_eq!(n, 0);

    let io_vec = [IoSlice::new(&[]); 3];
    let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
    assert_eq!(n, 0);

    assert_ok!(w.flush().await);
    assert!(w.get_ref().data.is_empty());
}

#[tokio::test]
async fn write_vectored_empty_on_vectored() {
    let mut w = BufWriter::new(MockWriter::vectored(4));
    let n = assert_ok!(write_vectored(&mut w, &[]).await);
    assert_eq!(n, 0);

    let io_vec = [IoSlice::new(&[]); 3];
    let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
    assert_eq!(n, 0);

    assert_ok!(w.flush().await);
    assert!(w.get_ref().data.is_empty());
}

#[tokio::test]
async fn write_vectored_basic_on_non_vectored() {
    let msg = b"foo bar baz";
    let bufs = [
        IoSlice::new(&msg[0..4]),
        IoSlice::new(&msg[4..8]),
        IoSlice::new(&msg[8..]),
    ];
    let mut w = BufWriter::new(MockWriter::new(4));
    let n = assert_ok!(write_vectored(&mut w, &bufs).await);
    assert_eq!(n, msg.len());
    assert!(w.buffer() == &msg[..]);
    assert_ok!(w.flush().await);
    assert_eq!(w.get_ref().data, msg);
}

#[tokio::test]
async fn write_vectored_basic_on_vectored() {
    let msg = b"foo bar baz";
    let bufs = [
        IoSlice::new(&msg[0..4]),
        IoSlice::new(&msg[4..8]),
        IoSlice::new(&msg[8..]),
    ];
    let mut w = BufWriter::new(MockWriter::vectored(4));
    let n = assert_ok!(write_vectored(&mut w, &bufs).await);
    assert_eq!(n, msg.len());
    assert!(w.buffer() == &msg[..]);
    assert_ok!(w.flush().await);
    assert_eq!(w.get_ref().data, msg);
}

#[tokio::test]
async fn write_vectored_large_total_on_non_vectored() {
    let msg = b"foo bar baz";
    let mut bufs = [
        IoSlice::new(&msg[0..4]),
        IoSlice::new(&msg[4..8]),
        IoSlice::new(&msg[8..]),
    ];
    let io_vec = IoBufs::new(&mut bufs);
    let mut w = BufWriter::with_capacity(8, MockWriter::new(4));
    let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
    assert_eq!(n, 8);
    assert!(w.buffer() == &msg[..8]);
    let io_vec = io_vec.advance(n);
    let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
    assert_eq!(n, 3);
    assert!(w.get_ref().data.as_slice() == &msg[..8]);
    assert!(w.buffer() == &msg[8..]);
}

#[tokio::test]
async fn write_vectored_large_total_on_vectored() {
    let msg = b"foo bar baz";
    let mut bufs = [
        IoSlice::new(&msg[0..4]),
        IoSlice::new(&msg[4..8]),
        IoSlice::new(&msg[8..]),
    ];
    let io_vec = IoBufs::new(&mut bufs);
    let mut w = BufWriter::with_capacity(8, MockWriter::vectored(10));
    let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
    assert_eq!(n, 10);
    assert!(w.buffer().is_empty());
    let io_vec = io_vec.advance(n);
    let n = assert_ok!(write_vectored(&mut w, &io_vec).await);
    assert_eq!(n, 1);
    assert!(w.get_ref().data.as_slice() == &msg[..10]);
    assert!(w.buffer() == &msg[10..]);
}

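// Drives repeated vectored writes through a `BufWriter<MockWriter>` until all
// supplied `IoBufs` are consumed, checking along the way that the buffer
// never exceeds its configured capacity.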
struct VectoredWriteHarness {
    writer: BufWriter<MockWriter>,
    buf_capacity: usize,
}

impl VectoredWriteHarness {
    fn new(buf_capacity: usize) -> Self {
        VectoredWriteHarness {
            writer: BufWriter::with_capacity(buf_capacity, MockWriter::new(4)),
            buf_capacity,
        }
    }

    fn with_vectored_backend(buf_capacity: usize) -> Self {
        VectoredWriteHarness {
            writer: BufWriter::with_capacity(buf_capacity, MockWriter::vectored(4)),
            buf_capacity,
        }
    }

    async fn write_all<'a, 'b>(&mut self, mut io_vec: IoBufs<'a, 'b>) -> usize {
        let mut total_written = 0;
        while !io_vec.is_empty() {
            let n = assert_ok!(write_vectored(&mut self.writer, &io_vec).await);
            assert!(n != 0);
            assert!(self.writer.buffer().len() <= self.buf_capacity);
            total_written += n;
            io_vec = io_vec.advance(n);
        }
        total_written
    }

    async fn flush(&mut self) -> &[u8] {
        assert_ok!(self.writer.flush().await);
        &self.writer.get_ref().data
    }
}

#[tokio::test]
async fn write_vectored_odd_on_non_vectored() {
    let msg = b"foo bar baz";
    let mut bufs = [
        IoSlice::new(&msg[0..4]),
        IoSlice::new(&[]),
        IoSlice::new(&msg[4..9]),
        IoSlice::new(&msg[9..]),
    ];
    let mut h = VectoredWriteHarness::new(8);
    let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await;
    assert_eq!(bytes_written, msg.len());
    assert_eq!(h.flush().await, msg);
}

#[tokio::test]
async fn write_vectored_odd_on_vectored() {
    let msg = b"foo bar baz";
    let mut bufs = [
        IoSlice::new(&msg[0..4]),
        IoSlice::new(&[]),
        IoSlice::new(&msg[4..9]),
        IoSlice::new(&msg[9..]),
    ];
    let mut h = VectoredWriteHarness::with_vectored_backend(8);
    let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await;
    assert_eq!(bytes_written, msg.len());
    assert_eq!(h.flush().await, msg);
}

#[tokio::test]
async fn write_vectored_large_slice_on_non_vectored() {
    let msg = b"foo bar baz";
    let mut bufs = [
        IoSlice::new(&[]),
        IoSlice::new(&msg[..9]),
        IoSlice::new(&msg[9..]),
    ];
    let mut h = VectoredWriteHarness::new(8);
    let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await;
    assert_eq!(bytes_written, msg.len());
    assert_eq!(h.flush().await, msg);
}

#[tokio::test]
async fn write_vectored_large_slice_on_vectored() {
    let msg = b"foo bar baz";
    let mut bufs = [
        IoSlice::new(&[]),
        IoSlice::new(&msg[..9]),
        IoSlice::new(&msg[9..]),
    ];
    let mut h = VectoredWriteHarness::with_vectored_backend(8);
    let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await;
    assert_eq!(bytes_written, msg.len());
    assert_eq!(h.flush().await, msg);
}