1 #![warn(rust_2018_idioms)]
2 #![cfg(feature = "full")]
3 
4 use tokio::io::{AsyncWrite, AsyncWriteExt};
5 use tokio_test::{assert_err, assert_ok};
6 
7 use bytes::{Buf, Bytes, BytesMut};
8 use std::cmp;
9 use std::io;
10 use std::pin::Pin;
11 use std::task::{Context, Poll};
12 
13 #[tokio::test]
write_all_buf()14 async fn write_all_buf() {
15     struct Wr {
16         buf: BytesMut,
17         cnt: usize,
18     }
19 
20     impl AsyncWrite for Wr {
21         fn poll_write(
22             mut self: Pin<&mut Self>,
23             _cx: &mut Context<'_>,
24             buf: &[u8],
25         ) -> Poll<io::Result<usize>> {
26             let n = cmp::min(4, buf.len());
27             dbg!(buf);
28             let buf = &buf[0..n];
29 
30             self.cnt += 1;
31             self.buf.extend(buf);
32             Ok(buf.len()).into()
33         }
34 
35         fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
36             Ok(()).into()
37         }
38 
39         fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
40             Ok(()).into()
41         }
42     }
43 
44     let mut wr = Wr {
45         buf: BytesMut::with_capacity(64),
46         cnt: 0,
47     };
48 
49     let mut buf = Bytes::from_static(b"hello").chain(Bytes::from_static(b"world"));
50 
51     assert_ok!(wr.write_all_buf(&mut buf).await);
52     assert_eq!(wr.buf, b"helloworld"[..]);
53     // expect 4 writes, [hell],[o],[worl],[d]
54     assert_eq!(wr.cnt, 4);
55     assert_eq!(buf.has_remaining(), false);
56 }
57 
58 #[tokio::test]
write_buf_err()59 async fn write_buf_err() {
60     /// Error out after writing the first 4 bytes
61     struct Wr {
62         cnt: usize,
63     }
64 
65     impl AsyncWrite for Wr {
66         fn poll_write(
67             mut self: Pin<&mut Self>,
68             _cx: &mut Context<'_>,
69             _buf: &[u8],
70         ) -> Poll<io::Result<usize>> {
71             self.cnt += 1;
72             if self.cnt == 2 {
73                 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "whoops")));
74             }
75             Poll::Ready(Ok(4))
76         }
77 
78         fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
79             Ok(()).into()
80         }
81 
82         fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
83             Ok(()).into()
84         }
85     }
86 
87     let mut wr = Wr { cnt: 0 };
88 
89     let mut buf = Bytes::from_static(b"hello").chain(Bytes::from_static(b"world"));
90 
91     assert_err!(wr.write_all_buf(&mut buf).await);
92     assert_eq!(
93         buf.copy_to_bytes(buf.remaining()),
94         Bytes::from_static(b"oworld")
95     );
96 }
97