1 use crate::codec::UserError;
2 use crate::codec::UserError::*;
3 use crate::frame::{self, Frame, FrameSize};
4 use crate::hpack;
5 
6 use bytes::{Buf, BufMut, BytesMut};
7 use std::pin::Pin;
8 use std::task::{Context, Poll};
9 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10 
11 use std::io::{self, Cursor, IoSlice};
12 
// Expanded inline (instead of being an `Encoder` method) so the returned
// `Limit` borrows only `$self.buf`; callers can then still borrow other
// fields such as `$self.hpack` while encoding into it.
//
// The limit caps writes at one maximum-size frame plus its header.
macro_rules! limited_write_buf {
    ($self:expr) => {{
        let cap = frame::HEADER_LEN + $self.max_frame_size();
        $self.buf.get_mut().limit(cap)
    }};
}
20 
/// Encodes HTTP/2 frames into an internal buffer and writes them to the
/// wrapped I/O object `T`.
///
/// `B` is the payload buffer type carried by DATA frames.
#[derive(Debug)]
pub struct FramedWrite<T, B> {
    /// Upstream `AsyncWrite`
    inner: T,

    /// Frame encoder and write-buffer state; see `Encoder`.
    encoder: Encoder<B>,
}
28 
/// Buffering/encoding half of `FramedWrite`: turns `Frame`s into bytes and
/// tracks what is still waiting to be written.
#[derive(Debug)]
struct Encoder<B> {
    /// HPACK encoder (used for HEADERS and PUSH_PROMISE header blocks)
    hpack: hpack::Encoder,

    /// Write buffer
    ///
    /// The cursor position marks how much has already been written out; it
    /// is reset together with the contents once everything is flushed.
    ///
    /// TODO: Should this be a ring buffer?
    buf: Cursor<BytesMut>,

    /// Next frame to encode
    ///
    /// Set when a frame did not fit entirely into `buf`: either a large DATA
    /// frame whose payload is written chained after `buf`, or the CONTINUATION
    /// remainder of a header block.
    next: Option<Next<B>>,

    /// Last data frame
    ///
    /// The most recent DATA frame that was fully encoded or written; handed
    /// back to callers through `FramedWrite::take_last_data_frame`.
    last_data_frame: Option<frame::Data<B>>,

    /// Max frame size, this is specified by the peer
    max_frame_size: FrameSize,

    /// Whether or not the wrapped `AsyncWrite` supports vectored IO.
    is_write_vectored: bool,
}
51 
/// Work left over after `Encoder::buffer` placed what it could into the
/// write buffer.
#[derive(Debug)]
enum Next<B> {
    /// A DATA frame whose head is already in the write buffer; its payload is
    /// written directly from its own buffer, chained after that head.
    Data(frame::Data<B>),
    /// The rest of a HEADERS / PUSH_PROMISE header block. It is encoded as a
    /// CONTINUATION frame once the write buffer has been flushed and cleared.
    Continuation(frame::Continuation),
}
57 
58 /// Initialize the connection with this amount of write buffer.
59 ///
60 /// The minimum MAX_FRAME_SIZE is 16kb, so always be able to send a HEADERS
61 /// frame that big.
62 const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024;
63 
64 /// Min buffer required to attempt to write a frame
65 const MIN_BUFFER_CAPACITY: usize = frame::HEADER_LEN + CHAIN_THRESHOLD;
66 
67 /// Chain payloads bigger than this. The remote will never advertise a max frame
68 /// size less than this (well, the spec says the max frame size can't be less
69 /// than 16kb, so not even close).
70 const CHAIN_THRESHOLD: usize = 256;
71 
72 // TODO: Make generic
73 impl<T, B> FramedWrite<T, B>
74 where
75     T: AsyncWrite + Unpin,
76     B: Buf,
77 {
new(inner: T) -> FramedWrite<T, B>78     pub fn new(inner: T) -> FramedWrite<T, B> {
79         let is_write_vectored = inner.is_write_vectored();
80         FramedWrite {
81             inner,
82             encoder: Encoder {
83                 hpack: hpack::Encoder::default(),
84                 buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)),
85                 next: None,
86                 last_data_frame: None,
87                 max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE,
88                 is_write_vectored,
89             },
90         }
91     }
92 
93     /// Returns `Ready` when `send` is able to accept a frame
94     ///
95     /// Calling this function may result in the current contents of the buffer
96     /// to be flushed to `T`.
poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>>97     pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
98         if !self.encoder.has_capacity() {
99             // Try flushing
100             ready!(self.flush(cx))?;
101 
102             if !self.encoder.has_capacity() {
103                 return Poll::Pending;
104             }
105         }
106 
107         Poll::Ready(Ok(()))
108     }
109 
110     /// Buffer a frame.
111     ///
112     /// `poll_ready` must be called first to ensure that a frame may be
113     /// accepted.
buffer(&mut self, item: Frame<B>) -> Result<(), UserError>114     pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
115         self.encoder.buffer(item)
116     }
117 
118     /// Flush buffered data to the wire
flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>>119     pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
120         let span = tracing::trace_span!("FramedWrite::flush");
121         let _e = span.enter();
122 
123         loop {
124             while !self.encoder.is_empty() {
125                 match self.encoder.next {
126                     Some(Next::Data(ref mut frame)) => {
127                         tracing::trace!(queued_data_frame = true);
128                         let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut());
129                         ready!(write(
130                             &mut self.inner,
131                             self.encoder.is_write_vectored,
132                             &mut buf,
133                             cx,
134                         ))?
135                     }
136                     _ => {
137                         tracing::trace!(queued_data_frame = false);
138                         ready!(write(
139                             &mut self.inner,
140                             self.encoder.is_write_vectored,
141                             &mut self.encoder.buf,
142                             cx,
143                         ))?
144                     }
145                 }
146             }
147 
148             match self.encoder.unset_frame() {
149                 ControlFlow::Continue => (),
150                 ControlFlow::Break => break,
151             }
152         }
153 
154         tracing::trace!("flushing buffer");
155         // Flush the upstream
156         ready!(Pin::new(&mut self.inner).poll_flush(cx))?;
157 
158         Poll::Ready(Ok(()))
159     }
160 
161     /// Close the codec
shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>>162     pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
163         ready!(self.flush(cx))?;
164         Pin::new(&mut self.inner).poll_shutdown(cx)
165     }
166 }
167 
write<T, B>( writer: &mut T, is_write_vectored: bool, buf: &mut B, cx: &mut Context<'_>, ) -> Poll<io::Result<()>> where T: AsyncWrite + Unpin, B: Buf,168 fn write<T, B>(
169     writer: &mut T,
170     is_write_vectored: bool,
171     buf: &mut B,
172     cx: &mut Context<'_>,
173 ) -> Poll<io::Result<()>>
174 where
175     T: AsyncWrite + Unpin,
176     B: Buf,
177 {
178     // TODO(eliza): when tokio-util 0.5.1 is released, this
179     // could just use `poll_write_buf`...
180     const MAX_IOVS: usize = 64;
181     let n = if is_write_vectored {
182         let mut bufs = [IoSlice::new(&[]); MAX_IOVS];
183         let cnt = buf.chunks_vectored(&mut bufs);
184         ready!(Pin::new(writer).poll_write_vectored(cx, &bufs[..cnt]))?
185     } else {
186         ready!(Pin::new(writer).poll_write(cx, buf.chunk()))?
187     };
188     buf.advance(n);
189     Ok(()).into()
190 }
191 
/// Result of `Encoder::unset_frame`, telling `FramedWrite::flush` whether to
/// keep writing.
#[must_use]
enum ControlFlow {
    /// A CONTINUATION frame was just encoded into the buffer; write again.
    Continue,
    /// Nothing further is pending; stop the write loop.
    Break,
}
197 
198 impl<B> Encoder<B>
199 where
200     B: Buf,
201 {
unset_frame(&mut self) -> ControlFlow202     fn unset_frame(&mut self) -> ControlFlow {
203         // Clear internal buffer
204         self.buf.set_position(0);
205         self.buf.get_mut().clear();
206 
207         // The data frame has been written, so unset it
208         match self.next.take() {
209             Some(Next::Data(frame)) => {
210                 self.last_data_frame = Some(frame);
211                 debug_assert!(self.is_empty());
212                 ControlFlow::Break
213             }
214             Some(Next::Continuation(frame)) => {
215                 // Buffer the continuation frame, then try to write again
216                 let mut buf = limited_write_buf!(self);
217                 if let Some(continuation) = frame.encode(&mut buf) {
218                     self.next = Some(Next::Continuation(continuation));
219                 }
220                 ControlFlow::Continue
221             }
222             None => ControlFlow::Break,
223         }
224     }
225 
buffer(&mut self, item: Frame<B>) -> Result<(), UserError>226     fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
227         // Ensure that we have enough capacity to accept the write.
228         assert!(self.has_capacity());
229         let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item);
230         let _e = span.enter();
231 
232         tracing::debug!(frame = ?item, "send");
233 
234         match item {
235             Frame::Data(mut v) => {
236                 // Ensure that the payload is not greater than the max frame.
237                 let len = v.payload().remaining();
238 
239                 if len > self.max_frame_size() {
240                     return Err(PayloadTooBig);
241                 }
242 
243                 if len >= CHAIN_THRESHOLD {
244                     let head = v.head();
245 
246                     // Encode the frame head to the buffer
247                     head.encode(len, self.buf.get_mut());
248 
249                     // Save the data frame
250                     self.next = Some(Next::Data(v));
251                 } else {
252                     v.encode_chunk(self.buf.get_mut());
253 
254                     // The chunk has been fully encoded, so there is no need to
255                     // keep it around
256                     assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded");
257 
258                     // Save off the last frame...
259                     self.last_data_frame = Some(v);
260                 }
261             }
262             Frame::Headers(v) => {
263                 let mut buf = limited_write_buf!(self);
264                 if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
265                     self.next = Some(Next::Continuation(continuation));
266                 }
267             }
268             Frame::PushPromise(v) => {
269                 let mut buf = limited_write_buf!(self);
270                 if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
271                     self.next = Some(Next::Continuation(continuation));
272                 }
273             }
274             Frame::Settings(v) => {
275                 v.encode(self.buf.get_mut());
276                 tracing::trace!(rem = self.buf.remaining(), "encoded settings");
277             }
278             Frame::GoAway(v) => {
279                 v.encode(self.buf.get_mut());
280                 tracing::trace!(rem = self.buf.remaining(), "encoded go_away");
281             }
282             Frame::Ping(v) => {
283                 v.encode(self.buf.get_mut());
284                 tracing::trace!(rem = self.buf.remaining(), "encoded ping");
285             }
286             Frame::WindowUpdate(v) => {
287                 v.encode(self.buf.get_mut());
288                 tracing::trace!(rem = self.buf.remaining(), "encoded window_update");
289             }
290 
291             Frame::Priority(_) => {
292                 /*
293                 v.encode(self.buf.get_mut());
294                 tracing::trace!("encoded priority; rem={:?}", self.buf.remaining());
295                 */
296                 unimplemented!();
297             }
298             Frame::Reset(v) => {
299                 v.encode(self.buf.get_mut());
300                 tracing::trace!(rem = self.buf.remaining(), "encoded reset");
301             }
302         }
303 
304         Ok(())
305     }
306 
has_capacity(&self) -> bool307     fn has_capacity(&self) -> bool {
308         self.next.is_none() && self.buf.get_ref().remaining_mut() >= MIN_BUFFER_CAPACITY
309     }
310 
is_empty(&self) -> bool311     fn is_empty(&self) -> bool {
312         match self.next {
313             Some(Next::Data(ref frame)) => !frame.payload().has_remaining(),
314             _ => !self.buf.has_remaining(),
315         }
316     }
317 }
318 
319 impl<B> Encoder<B> {
max_frame_size(&self) -> usize320     fn max_frame_size(&self) -> usize {
321         self.max_frame_size as usize
322     }
323 }
324 
325 impl<T, B> FramedWrite<T, B> {
326     /// Returns the max frame size that can be sent
max_frame_size(&self) -> usize327     pub fn max_frame_size(&self) -> usize {
328         self.encoder.max_frame_size()
329     }
330 
331     /// Set the peer's max frame size.
set_max_frame_size(&mut self, val: usize)332     pub fn set_max_frame_size(&mut self, val: usize) {
333         assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize);
334         self.encoder.max_frame_size = val as FrameSize;
335     }
336 
337     /// Set the peer's header table size.
set_header_table_size(&mut self, val: usize)338     pub fn set_header_table_size(&mut self, val: usize) {
339         self.encoder.hpack.update_max_size(val);
340     }
341 
342     /// Retrieve the last data frame that has been sent
take_last_data_frame(&mut self) -> Option<frame::Data<B>>343     pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> {
344         self.encoder.last_data_frame.take()
345     }
346 
get_mut(&mut self) -> &mut T347     pub fn get_mut(&mut self) -> &mut T {
348         &mut self.inner
349     }
350 }
351 
352 impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> {
poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf, ) -> Poll<io::Result<()>>353     fn poll_read(
354         mut self: Pin<&mut Self>,
355         cx: &mut Context<'_>,
356         buf: &mut ReadBuf,
357     ) -> Poll<io::Result<()>> {
358         Pin::new(&mut self.inner).poll_read(cx, buf)
359     }
360 }
361 
362 // We never project the Pin to `B`.
363 impl<T: Unpin, B> Unpin for FramedWrite<T, B> {}
364 
#[cfg(feature = "unstable")]
mod unstable {
    use super::*;

    impl<T, B> FramedWrite<T, B> {
        /// Shared access to the wrapped I/O object.
        pub fn get_ref(&self) -> &T {
            let FramedWrite { inner, .. } = self;
            inner
        }
    }
}
375