1 #![warn(rust_2018_idioms)]
2 #![cfg(feature = "full")]
3
4 use tokio::io::{AsyncWrite, AsyncWriteExt};
5 use tokio_test::{assert_err, assert_ok};
6
7 use bytes::{Buf, Bytes, BytesMut};
8 use std::cmp;
9 use std::io;
10 use std::pin::Pin;
11 use std::task::{Context, Poll};
12
13 #[tokio::test]
write_all_buf()14 async fn write_all_buf() {
15 struct Wr {
16 buf: BytesMut,
17 cnt: usize,
18 }
19
20 impl AsyncWrite for Wr {
21 fn poll_write(
22 mut self: Pin<&mut Self>,
23 _cx: &mut Context<'_>,
24 buf: &[u8],
25 ) -> Poll<io::Result<usize>> {
26 let n = cmp::min(4, buf.len());
27 dbg!(buf);
28 let buf = &buf[0..n];
29
30 self.cnt += 1;
31 self.buf.extend(buf);
32 Ok(buf.len()).into()
33 }
34
35 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
36 Ok(()).into()
37 }
38
39 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
40 Ok(()).into()
41 }
42 }
43
44 let mut wr = Wr {
45 buf: BytesMut::with_capacity(64),
46 cnt: 0,
47 };
48
49 let mut buf = Bytes::from_static(b"hello").chain(Bytes::from_static(b"world"));
50
51 assert_ok!(wr.write_all_buf(&mut buf).await);
52 assert_eq!(wr.buf, b"helloworld"[..]);
53 // expect 4 writes, [hell],[o],[worl],[d]
54 assert_eq!(wr.cnt, 4);
55 assert_eq!(buf.has_remaining(), false);
56 }
57
58 #[tokio::test]
write_buf_err()59 async fn write_buf_err() {
60 /// Error out after writing the first 4 bytes
61 struct Wr {
62 cnt: usize,
63 }
64
65 impl AsyncWrite for Wr {
66 fn poll_write(
67 mut self: Pin<&mut Self>,
68 _cx: &mut Context<'_>,
69 _buf: &[u8],
70 ) -> Poll<io::Result<usize>> {
71 self.cnt += 1;
72 if self.cnt == 2 {
73 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "whoops")));
74 }
75 Poll::Ready(Ok(4))
76 }
77
78 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
79 Ok(()).into()
80 }
81
82 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
83 Ok(()).into()
84 }
85 }
86
87 let mut wr = Wr { cnt: 0 };
88
89 let mut buf = Bytes::from_static(b"hello").chain(Bytes::from_static(b"world"));
90
91 assert_err!(wr.write_all_buf(&mut buf).await);
92 assert_eq!(
93 buf.copy_to_bytes(buf.remaining()),
94 Bytes::from_static(b"oworld")
95 );
96 }
97