1 use crate::codec::UserError;
2 use crate::codec::UserError::*;
3 use crate::frame::{self, Frame, FrameSize};
4 use crate::hpack;
5 
6 use bytes::{
7     buf::{BufExt, BufMutExt},
8     Buf, BufMut, BytesMut,
9 };
10 use std::pin::Pin;
11 use std::task::{Context, Poll};
12 use tokio::io::{AsyncRead, AsyncWrite};
13 
14 use std::io::{self, Cursor};
15 
// Produces a `Limit`-wrapped view of the write buffer that accepts at most one
// frame's worth of bytes: a frame head plus the peer's max frame size.
//
// This is a macro rather than a `&mut self` method because a method would
// borrow all of `self`. Expanding inline borrows only `$self.buf`, so the
// caller can still use other fields (e.g. `$self.hpack`) while the limited
// buffer is alive.
macro_rules! limited_write_buf {
    ($self:expr) => {{
        let max_bytes = frame::HEADER_LEN + $self.max_frame_size();
        $self.buf.get_mut().limit(max_bytes)
    }};
}
23 
/// Encodes HTTP/2 frames into an in-memory buffer and writes that buffer to
/// an upstream `AsyncWrite`.
///
/// DATA payloads of at least `CHAIN_THRESHOLD` bytes are not copied into the
/// buffer; they are written straight from the frame when flushing.
#[derive(Debug)]
pub struct FramedWrite<T, B> {
    /// Upstream `AsyncWrite`
    inner: T,

    /// HPACK encoder used when encoding HEADERS, PUSH_PROMISE and
    /// CONTINUATION header blocks.
    hpack: hpack::Encoder,

    /// Write buffer holding encoded frames. The cursor position advances as
    /// bytes are written to `inner`; the buffer is cleared once fully written.
    ///
    /// TODO: Should this be a ring buffer?
    buf: Cursor<BytesMut>,

    /// Work still owed by `flush`: a large DATA frame whose payload has not
    /// been written yet, or a CONTINUATION frame that still has to be encoded.
    next: Option<Next<B>>,

    /// Most recently written DATA frame, retrievable via
    /// `take_last_data_frame`.
    last_data_frame: Option<frame::Data<B>>,

    /// Max frame size, this is specified by the peer
    max_frame_size: FrameSize,
}
46 
/// A frame that `buffer` could not fully place into the write buffer and that
/// `flush` still has to finish.
#[derive(Debug)]
enum Next<B> {
    /// A DATA frame whose head is already in the write buffer; its payload is
    /// chained after the buffer and written directly from the frame.
    Data(frame::Data<B>),
    /// The rest of a header block returned by encoding a HEADERS,
    /// PUSH_PROMISE or previous CONTINUATION frame; it still needs to be
    /// encoded as a CONTINUATION frame.
    Continuation(frame::Continuation),
}
52 
53 /// Initialze the connection with this amount of write buffer.
54 ///
55 /// The minimum MAX_FRAME_SIZE is 16kb, so always be able to send a HEADERS
56 /// frame that big.
57 const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024;
58 
59 /// Min buffer required to attempt to write a frame
60 const MIN_BUFFER_CAPACITY: usize = frame::HEADER_LEN + CHAIN_THRESHOLD;
61 
62 /// Chain payloads bigger than this. The remote will never advertise a max frame
63 /// size less than this (well, the spec says the max frame size can't be less
64 /// than 16kb, so not even close).
65 const CHAIN_THRESHOLD: usize = 256;
66 
67 // TODO: Make generic
68 impl<T, B> FramedWrite<T, B>
69 where
70     T: AsyncWrite + Unpin,
71     B: Buf,
72 {
new(inner: T) -> FramedWrite<T, B>73     pub fn new(inner: T) -> FramedWrite<T, B> {
74         FramedWrite {
75             inner,
76             hpack: hpack::Encoder::default(),
77             buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)),
78             next: None,
79             last_data_frame: None,
80             max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE,
81         }
82     }
83 
84     /// Returns `Ready` when `send` is able to accept a frame
85     ///
86     /// Calling this function may result in the current contents of the buffer
87     /// to be flushed to `T`.
poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>>88     pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
89         if !self.has_capacity() {
90             // Try flushing
91             ready!(self.flush(cx))?;
92 
93             if !self.has_capacity() {
94                 return Poll::Pending;
95             }
96         }
97 
98         Poll::Ready(Ok(()))
99     }
100 
101     /// Buffer a frame.
102     ///
103     /// `poll_ready` must be called first to ensure that a frame may be
104     /// accepted.
buffer(&mut self, item: Frame<B>) -> Result<(), UserError>105     pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
106         // Ensure that we have enough capacity to accept the write.
107         assert!(self.has_capacity());
108 
109         tracing::debug!("send; frame={:?}", item);
110 
111         match item {
112             Frame::Data(mut v) => {
113                 // Ensure that the payload is not greater than the max frame.
114                 let len = v.payload().remaining();
115 
116                 if len > self.max_frame_size() {
117                     return Err(PayloadTooBig);
118                 }
119 
120                 if len >= CHAIN_THRESHOLD {
121                     let head = v.head();
122 
123                     // Encode the frame head to the buffer
124                     head.encode(len, self.buf.get_mut());
125 
126                     // Save the data frame
127                     self.next = Some(Next::Data(v));
128                 } else {
129                     v.encode_chunk(self.buf.get_mut());
130 
131                     // The chunk has been fully encoded, so there is no need to
132                     // keep it around
133                     assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded");
134 
135                     // Save off the last frame...
136                     self.last_data_frame = Some(v);
137                 }
138             }
139             Frame::Headers(v) => {
140                 let mut buf = limited_write_buf!(self);
141                 if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
142                     self.next = Some(Next::Continuation(continuation));
143                 }
144             }
145             Frame::PushPromise(v) => {
146                 let mut buf = limited_write_buf!(self);
147                 if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
148                     self.next = Some(Next::Continuation(continuation));
149                 }
150             }
151             Frame::Settings(v) => {
152                 v.encode(self.buf.get_mut());
153                 tracing::trace!("encoded settings; rem={:?}", self.buf.remaining());
154             }
155             Frame::GoAway(v) => {
156                 v.encode(self.buf.get_mut());
157                 tracing::trace!("encoded go_away; rem={:?}", self.buf.remaining());
158             }
159             Frame::Ping(v) => {
160                 v.encode(self.buf.get_mut());
161                 tracing::trace!("encoded ping; rem={:?}", self.buf.remaining());
162             }
163             Frame::WindowUpdate(v) => {
164                 v.encode(self.buf.get_mut());
165                 tracing::trace!("encoded window_update; rem={:?}", self.buf.remaining());
166             }
167 
168             Frame::Priority(_) => {
169                 /*
170                 v.encode(self.buf.get_mut());
171                 tracing::trace!("encoded priority; rem={:?}", self.buf.remaining());
172                 */
173                 unimplemented!();
174             }
175             Frame::Reset(v) => {
176                 v.encode(self.buf.get_mut());
177                 tracing::trace!("encoded reset; rem={:?}", self.buf.remaining());
178             }
179         }
180 
181         Ok(())
182     }
183 
184     /// Flush buffered data to the wire
flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>>185     pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
186         tracing::trace!("flush");
187 
188         loop {
189             while !self.is_empty() {
190                 match self.next {
191                     Some(Next::Data(ref mut frame)) => {
192                         tracing::trace!("  -> queued data frame");
193                         let mut buf = (&mut self.buf).chain(frame.payload_mut());
194                         ready!(Pin::new(&mut self.inner).poll_write_buf(cx, &mut buf))?;
195                     }
196                     _ => {
197                         tracing::trace!("  -> not a queued data frame");
198                         ready!(Pin::new(&mut self.inner).poll_write_buf(cx, &mut self.buf))?;
199                     }
200                 }
201             }
202 
203             // Clear internal buffer
204             self.buf.set_position(0);
205             self.buf.get_mut().clear();
206 
207             // The data frame has been written, so unset it
208             match self.next.take() {
209                 Some(Next::Data(frame)) => {
210                     self.last_data_frame = Some(frame);
211                     debug_assert!(self.is_empty());
212                     break;
213                 }
214                 Some(Next::Continuation(frame)) => {
215                     // Buffer the continuation frame, then try to write again
216                     let mut buf = limited_write_buf!(self);
217                     if let Some(continuation) = frame.encode(&mut self.hpack, &mut buf) {
218                         // We previously had a CONTINUATION, and after encoding
219                         // it, we got *another* one? Let's just double check
220                         // that at least some progress is being made...
221                         if self.buf.get_ref().len() == frame::HEADER_LEN {
222                             // If *only* the CONTINUATION frame header was
223                             // written, and *no* header fields, we're stuck
224                             // in a loop...
225                             panic!("CONTINUATION frame write loop; header value too big to encode");
226                         }
227 
228                         self.next = Some(Next::Continuation(continuation));
229                     }
230                 }
231                 None => {
232                     break;
233                 }
234             }
235         }
236 
237         tracing::trace!("flushing buffer");
238         // Flush the upstream
239         ready!(Pin::new(&mut self.inner).poll_flush(cx))?;
240 
241         Poll::Ready(Ok(()))
242     }
243 
244     /// Close the codec
shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>>245     pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
246         ready!(self.flush(cx))?;
247         Pin::new(&mut self.inner).poll_shutdown(cx)
248     }
249 
has_capacity(&self) -> bool250     fn has_capacity(&self) -> bool {
251         self.next.is_none() && self.buf.get_ref().remaining_mut() >= MIN_BUFFER_CAPACITY
252     }
253 
is_empty(&self) -> bool254     fn is_empty(&self) -> bool {
255         match self.next {
256             Some(Next::Data(ref frame)) => !frame.payload().has_remaining(),
257             _ => !self.buf.has_remaining(),
258         }
259     }
260 }
261 
262 impl<T, B> FramedWrite<T, B> {
263     /// Returns the max frame size that can be sent
max_frame_size(&self) -> usize264     pub fn max_frame_size(&self) -> usize {
265         self.max_frame_size as usize
266     }
267 
268     /// Set the peer's max frame size.
set_max_frame_size(&mut self, val: usize)269     pub fn set_max_frame_size(&mut self, val: usize) {
270         assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize);
271         self.max_frame_size = val as FrameSize;
272     }
273 
274     /// Set the peer's header table size.
set_header_table_size(&mut self, val: usize)275     pub fn set_header_table_size(&mut self, val: usize) {
276         self.hpack.update_max_size(val);
277     }
278 
279     /// Retrieve the last data frame that has been sent
take_last_data_frame(&mut self) -> Option<frame::Data<B>>280     pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> {
281         self.last_data_frame.take()
282     }
283 
get_mut(&mut self) -> &mut T284     pub fn get_mut(&mut self) -> &mut T {
285         &mut self.inner
286     }
287 }
288 
289 impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> {
prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool290     unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
291         self.inner.prepare_uninitialized_buffer(buf)
292     }
293 
poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<io::Result<usize>>294     fn poll_read(
295         mut self: Pin<&mut Self>,
296         cx: &mut Context<'_>,
297         buf: &mut [u8],
298     ) -> Poll<io::Result<usize>> {
299         Pin::new(&mut self.inner).poll_read(cx, buf)
300     }
301 
poll_read_buf<Buf: BufMut>( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut Buf, ) -> Poll<io::Result<usize>>302     fn poll_read_buf<Buf: BufMut>(
303         mut self: Pin<&mut Self>,
304         cx: &mut Context<'_>,
305         buf: &mut Buf,
306     ) -> Poll<io::Result<usize>> {
307         Pin::new(&mut self.inner).poll_read_buf(cx, buf)
308     }
309 }
310 
// We never project the Pin to `B`: only `inner` is ever re-pinned (see the
// `AsyncRead` impl). `FramedWrite` is therefore `Unpin` whenever `T` is,
// regardless of `B`.
impl<T: Unpin, B> Unpin for FramedWrite<T, B> {}
313 
// Extra accessors compiled only when the `unstable` cargo feature is enabled.
#[cfg(feature = "unstable")]
mod unstable {
    use super::*;

    impl<T, B> FramedWrite<T, B> {
        /// Returns a shared reference to the wrapped I/O object.
        pub fn get_ref(&self) -> &T {
            &self.inner
        }
    }
}
324