use crate::codec::UserError;
use crate::codec::UserError::*;
use crate::frame::{self, Frame, FrameSize};
use crate::hpack;

use bytes::{
    buf::{BufExt, BufMutExt},
    Buf, BufMut, BytesMut,
};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};

use std::io::{self, Cursor};

// A macro to get around a method needing to borrow &mut self.
// `max_frame_size()` takes `&self` while `buf.get_mut()` needs `&mut self`;
// inlining both in a macro lets the borrow checker see the borrows as disjoint,
// which a helper method returning the limited buffer could not.
macro_rules! limited_write_buf {
    ($self:expr) => {{
        let limit = $self.max_frame_size() + frame::HEADER_LEN;
        $self.buf.get_mut().limit(limit)
    }};
}

/// Encodes HTTP/2 frames and writes them to an upstream `AsyncWrite`.
///
/// Frames are first encoded into `buf`; oversized DATA payloads and
/// overflowing header blocks are carried over in `next` so they can be
/// written across multiple `flush` calls without copying large payloads.
#[derive(Debug)]
pub struct FramedWrite<T, B> {
    /// Upstream `AsyncWrite`
    inner: T,

    /// HPACK encoder
    hpack: hpack::Encoder,

    /// Write buffer
    ///
    /// TODO: Should this be a ring buffer?
    buf: Cursor<BytesMut>,

    /// Next frame to encode
    next: Option<Next<B>>,

    /// Last data frame
    last_data_frame: Option<frame::Data<B>>,

    /// Max frame size, this is specified by the peer
    max_frame_size: FrameSize,
}

/// Carry-over work from a previous `buffer` call that must be fully written
/// before a new frame can be accepted.
#[derive(Debug)]
enum Next<B> {
    /// A DATA frame whose payload is chained onto the write rather than
    /// copied into `buf` (only the frame header goes through `buf`).
    Data(frame::Data<B>),
    /// A CONTINUATION frame holding header fields that did not fit in the
    /// previous HEADERS / PUSH_PROMISE / CONTINUATION frame.
    Continuation(frame::Continuation),
}

/// Initialize the connection with this amount of write buffer.
///
/// The minimum MAX_FRAME_SIZE is 16kb, so always be able to send a HEADERS
/// frame that big.
const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024;

/// Min buffer required to attempt to write a frame
const MIN_BUFFER_CAPACITY: usize = frame::HEADER_LEN + CHAIN_THRESHOLD;

/// Chain payloads bigger than this. The remote will never advertise a max frame
/// size less than this (well, the spec says the max frame size can't be less
/// than 16kb, so not even close).
const CHAIN_THRESHOLD: usize = 256;

// TODO: Make generic
impl<T, B> FramedWrite<T, B>
where
    T: AsyncWrite + Unpin,
    B: Buf,
{
    /// Creates a new `FramedWrite` wrapping `inner`, with a default HPACK
    /// encoder, an empty write buffer, and the protocol-default max frame size.
    pub fn new(inner: T) -> FramedWrite<T, B> {
        FramedWrite {
            inner,
            hpack: hpack::Encoder::default(),
            buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)),
            next: None,
            last_data_frame: None,
            max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE,
        }
    }

    /// Returns `Ready` when `send` is able to accept a frame
    ///
    /// Calling this function may result in the current contents of the buffer
    /// to be flushed to `T`.
    pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        if !self.has_capacity() {
            // Try flushing
            ready!(self.flush(cx))?;

            // A completed flush clears `next` and the buffer, so lack of
            // capacity here means the write could not make enough progress.
            if !self.has_capacity() {
                return Poll::Pending;
            }
        }

        Poll::Ready(Ok(()))
    }

    /// Buffer a frame.
    ///
    /// `poll_ready` must be called first to ensure that a frame may be
    /// accepted.
    pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
        // Ensure that we have enough capacity to accept the write.
        assert!(self.has_capacity());

        log::debug!("send; frame={:?}", item);

        match item {
            Frame::Data(mut v) => {
                // Ensure that the payload is not greater than the max frame.
                let len = v.payload().remaining();

                if len > self.max_frame_size() {
                    return Err(PayloadTooBig);
                }

                if len >= CHAIN_THRESHOLD {
                    // Large payload: avoid copying it into `buf`. Only the
                    // frame header is buffered; the payload itself is chained
                    // onto the write during `flush`.
                    let head = v.head();

                    // Encode the frame head to the buffer
                    head.encode(len, self.buf.get_mut());

                    // Save the data frame
                    self.next = Some(Next::Data(v));
                } else {
                    // Small payload: copy it directly into the write buffer.
                    v.encode_chunk(self.buf.get_mut());

                    // The chunk has been fully encoded, so there is no need to
                    // keep it around
                    assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded");

                    // Save off the last frame...
                    self.last_data_frame = Some(v);
                }
            }
            Frame::Headers(v) => {
                // Header blocks are limited to one max-size frame; any
                // overflow comes back as a CONTINUATION to encode later.
                let mut buf = limited_write_buf!(self);
                if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
                    self.next = Some(Next::Continuation(continuation));
                }
            }
            Frame::PushPromise(v) => {
                let mut buf = limited_write_buf!(self);
                if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
                    self.next = Some(Next::Continuation(continuation));
                }
            }
            Frame::Settings(v) => {
                v.encode(self.buf.get_mut());
                log::trace!("encoded settings; rem={:?}", self.buf.remaining());
            }
            Frame::GoAway(v) => {
                v.encode(self.buf.get_mut());
                log::trace!("encoded go_away; rem={:?}", self.buf.remaining());
            }
            Frame::Ping(v) => {
                v.encode(self.buf.get_mut());
                log::trace!("encoded ping; rem={:?}", self.buf.remaining());
            }
            Frame::WindowUpdate(v) => {
                v.encode(self.buf.get_mut());
                log::trace!("encoded window_update; rem={:?}", self.buf.remaining());
            }

            Frame::Priority(_) => {
                /*
                v.encode(self.buf.get_mut());
                log::trace!("encoded priority; rem={:?}", self.buf.remaining());
                */
                unimplemented!();
            }
            Frame::Reset(v) => {
                v.encode(self.buf.get_mut());
                log::trace!("encoded reset; rem={:?}", self.buf.remaining());
            }
        }

        Ok(())
    }

    /// Flush buffered data to the wire
    pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        log::trace!("flush");

        // Outer loop: each iteration drains the buffer (plus any chained DATA
        // payload), then resolves the pending `next` state. A CONTINUATION may
        // re-fill the buffer, requiring another pass.
        loop {
            while !self.is_empty() {
                match self.next {
                    Some(Next::Data(ref mut frame)) => {
                        log::trace!("  -> queued data frame");
                        // Write the buffered frame header followed by the
                        // payload, without copying the payload into `buf`.
                        let mut buf = (&mut self.buf).chain(frame.payload_mut());
                        ready!(Pin::new(&mut self.inner).poll_write_buf(cx, &mut buf))?;
                    }
                    _ => {
                        log::trace!("  -> not a queued data frame");
                        ready!(Pin::new(&mut self.inner).poll_write_buf(cx, &mut self.buf))?;
                    }
                }
            }

            // Clear internal buffer
            self.buf.set_position(0);
            self.buf.get_mut().clear();

            // The data frame has been written, so unset it
            match self.next.take() {
                Some(Next::Data(frame)) => {
                    // Keep the drained frame around so the caller can reclaim
                    // it (e.g. for flow-control bookkeeping).
                    self.last_data_frame = Some(frame);
                    debug_assert!(self.is_empty());
                    break;
                }
                Some(Next::Continuation(frame)) => {
                    // Buffer the continuation frame, then try to write again
                    let mut buf = limited_write_buf!(self);
                    if let Some(continuation) = frame.encode(&mut self.hpack, &mut buf) {
                        // We previously had a CONTINUATION, and after encoding
                        // it, we got *another* one? Let's just double check
                        // that at least some progress is being made...
                        if self.buf.get_ref().len() == frame::HEADER_LEN {
                            // If *only* the CONTINUATION frame header was
                            // written, and *no* header fields, we're stuck
                            // in a loop...
                            panic!("CONTINUATION frame write loop; header value too big to encode");
                        }

                        self.next = Some(Next::Continuation(continuation));
                    }
                }
                None => {
                    break;
                }
            }
        }

        log::trace!("flushing buffer");
        // Flush the upstream
        ready!(Pin::new(&mut self.inner).poll_flush(cx))?;

        Poll::Ready(Ok(()))
    }

    /// Close the codec
    pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        // Drain everything we buffered before shutting the transport down.
        ready!(self.flush(cx))?;
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }

    // A frame can be accepted only when no carry-over work is pending and the
    // buffer can hold at least a frame header plus a sub-threshold payload.
    fn has_capacity(&self) -> bool {
        self.next.is_none() && self.buf.get_ref().remaining_mut() >= MIN_BUFFER_CAPACITY
    }

    // "Empty" means nothing left to write: for a queued DATA frame that is the
    // payload; otherwise it is the write buffer itself.
    fn is_empty(&self) -> bool {
        match self.next {
            Some(Next::Data(ref frame)) => !frame.payload().has_remaining(),
            _ => !self.buf.has_remaining(),
        }
    }
}

impl<T, B> FramedWrite<T, B> {
    /// Returns the max frame size that can be sent
    pub fn max_frame_size(&self) -> usize {
        self.max_frame_size as usize
    }

    /// Set the peer's max frame size.
    pub fn set_max_frame_size(&mut self, val: usize) {
        // Never accept a value above the protocol-defined maximum.
        assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize);
        self.max_frame_size = val as FrameSize;
    }

    /// Retrieve the last data frame that has been sent
    pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> {
        self.last_data_frame.take()
    }

    /// Returns a mutable reference to the underlying `AsyncWrite`.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.inner
    }
}

// `FramedWrite` only intercepts the write half; reads pass straight through
// to the inner stream so the type can be used as a full duplex transport.
impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> {
    // Delegates to `inner`; the safety contract is the inner stream's.
    unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
        self.inner.prepare_uninitialized_buffer(buf)
    }

    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_read(cx, buf)
    }

    fn poll_read_buf<Buf: BufMut>(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut Buf,
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_read_buf(cx, buf)
    }
}

// We never project the Pin to `B`.
impl<T: Unpin, B> Unpin for FramedWrite<T, B> {}

#[cfg(feature = "unstable")]
mod unstable {
    use super::*;

    impl<T, B> FramedWrite<T, B> {
        /// Returns a shared reference to the underlying stream.
        pub fn get_ref(&self) -> &T {
            &self.inner
        }
    }
}