1 #![warn(rust_2018_idioms)]
2 
3 use tokio::io::AsyncWrite;
4 use tokio_test::{assert_ready, task};
5 use tokio_util::codec::{Encoder, FramedWrite};
6 
7 use bytes::{BufMut, BytesMut};
8 use futures_sink::Sink;
9 use std::collections::VecDeque;
10 use std::io::{self, Write};
11 use std::pin::Pin;
12 use std::task::Poll::{Pending, Ready};
13 use std::task::{Context, Poll};
14 
/// Builds a `Mock` whose scripted calls are the given expressions, in order.
///
/// Each argument becomes one entry of the mock's `calls` queue. The list may
/// be empty, and a trailing comma after the last argument is optional.
macro_rules! mock {
    // No scripted calls at all.
    () => {{
        Mock {
            calls: VecDeque::new(),
        }
    }};
    // One or more scripted calls, with or without a trailing comma.
    ($($call:expr),+ $(,)?) => {{
        let mut calls = VecDeque::new();
        calls.extend(vec![$($call),+]);
        Mock { calls }
    }};
}
22 
/// Wraps a mutable borrow of the named local in a `Pin`.
///
/// Goes through `Pin::new`, so the local's type must be `Unpin`.
macro_rules! pin {
    ($target:ident) => {
        Pin::new(&mut $target)
    };
}
28 
29 struct U32Encoder;
30 
31 impl Encoder<u32> for U32Encoder {
32     type Error = io::Error;
33 
encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()>34     fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> {
35         // Reserve space
36         dst.reserve(4);
37         dst.put_u32(item);
38         Ok(())
39     }
40 }
41 
/// Sends three `u32` frames and checks that they reach the mock as a single
/// write once the sink is flushed.
#[test]
fn write_multi_frame_in_packet() {
    let mut task = task::spawn(());

    // Frames 0, 1 and 2 back to back, as the mock expects to receive them.
    let expected_write = b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec();
    let mock = mock! {
        Ok(expected_write),
    };
    let mut framed = FramedWrite::new(mock, U32Encoder);

    task.enter(|cx, _| {
        for frame in 0..3u32 {
            assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
            assert!(pin!(framed).start_send(frame).is_ok());
        }

        // The scripted write is still queued, so nothing has been written yet.
        assert_eq!(1, framed.get_ref().calls.len());

        // Flushing consumes the single scripted write.
        assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok());

        assert_eq!(0, framed.get_ref().calls.len());
    });
}
67 
/// Fills the sink's buffer, lets the first flush attempt block, and checks
/// that a later poll drains everything the mock was scripted to receive.
#[test]
fn write_hits_backpressure() {
    const ITER: usize = 2 * 1024;

    // The script starts with a write that blocks, followed by an empty chunk
    // that the loop below fills in.
    let mut mock = mock! {
        Err(io::Error::new(io::ErrorKind::WouldBlock, "not ready")),
        Ok(Vec::new()),
    };

    for i in 0..=ITER {
        // Same bytes the encoder produces for this item.
        let frame = (i as u32).to_be_bytes();

        let tail = match mock.calls.back_mut().unwrap() {
            Ok(data) => data,
            Err(_) => unreachable!(),
        };

        if tail.len() < ITER {
            // Expected writes are grouped into 2KB chunks.
            tail.extend_from_slice(&frame);
        } else {
            // The current chunk is full: start a new one.
            mock.calls.push_back(Ok(frame.to_vec()));
        }
    }
    // 1 'wouldblock', 4 * 2KB buffers, 1 4-byte buffer
    assert_eq!(mock.calls.len(), 6);

    let mut task = task::spawn(());
    let mut framed = FramedWrite::new(mock, U32Encoder);
    task.enter(|cx, _| {
        // Send 8KB. This fills up the FramedWrite buffer.
        for i in 0..ITER {
            assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
            assert!(pin!(framed).start_send(i as u32).is_ok());
        }

        // Polling for readiness now forces a flush. The mock pops its first
        // scripted call, which is the blocking one.
        assert!(pin!(framed).poll_ready(cx).is_pending());

        // Polling again forces another flush, which succeeds this time and
        // drains the whole 8KB buffer.
        assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());

        // Send one more item; it matches the final chunk in the mock's script.
        assert!(pin!(framed).start_send(ITER as u32).is_ok());

        // Flush the rest of the buffer.
        assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok());

        // Every scripted call must have been consumed.
        assert_eq!(0, framed.get_ref().calls.len());
    });
}
127 
// ===== Mock =====
129 
/// Scripted writer: each `write` consumes the next entry of `calls`.
struct Mock {
    /// Remaining scripted outcomes, consumed from the front. `Ok(bytes)` is
    /// the data the next write must start with; `Err(e)` is returned as-is.
    calls: VecDeque<io::Result<Vec<u8>>>,
}

impl Write for Mock {
    /// Pops the next scripted call and checks `src` against it.
    ///
    /// Returns the length of the scripted chunk, or the scripted error.
    ///
    /// # Panics
    ///
    /// Panics if the script is exhausted, or if `src` does not begin with the
    /// scripted bytes.
    fn write(&mut self, src: &[u8]) -> io::Result<usize> {
        let expected = match self.calls.pop_front() {
            None => panic!("unexpected write; {:?}", src),
            Some(scripted) => scripted?,
        };

        let accepted = expected.len();
        assert!(src.len() >= accepted);
        assert_eq!(&expected[..], &src[..accepted]);
        Ok(accepted)
    }

    /// Flushing is a no-op that always succeeds.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
151 
152 impl AsyncWrite for Mock {
poll_write( self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, io::Error>>153     fn poll_write(
154         self: Pin<&mut Self>,
155         _cx: &mut Context<'_>,
156         buf: &[u8],
157     ) -> Poll<Result<usize, io::Error>> {
158         match Pin::get_mut(self).write(buf) {
159             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending,
160             other => Ready(other),
161         }
162     }
poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>163     fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
164         match Pin::get_mut(self).flush() {
165             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending,
166             other => Ready(other),
167         }
168     }
poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>169     fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
170         unimplemented!()
171     }
172 }
173