1 use crate::codec::UserError;
2 use crate::codec::UserError::*;
3 use crate::frame::{self, Frame, FrameSize};
4 use crate::hpack;
5
6 use bytes::{Buf, BufMut, BytesMut};
7 use std::pin::Pin;
8 use std::task::{Context, Poll};
9 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10
11 use std::io::{self, Cursor, IoSlice};
12
// A macro to get around a method needing to borrow &mut self.
//
// `$self.max_frame_size()` borrows `$self` immutably while
// `$self.buf.get_mut()` needs a mutable borrow; computing the limit into a
// local first (inside the macro expansion, so the field borrows stay
// disjoint) sidesteps the conflict. The returned `Limit` writer caps what a
// single encode pass may append to one full frame: payload plus frame header.
macro_rules! limited_write_buf {
    ($self:expr) => {{
        let limit = $self.max_frame_size() + frame::HEADER_LEN;
        $self.buf.get_mut().limit(limit)
    }};
}
20
/// Encodes HTTP/2 frames and writes them to an underlying `AsyncWrite`.
///
/// `T` is the IO transport; `B` is the buffer type carried by DATA frame
/// payloads. Frames are staged through the internal [`Encoder`] and pushed
/// to `inner` by `flush`.
#[derive(Debug)]
pub struct FramedWrite<T, B> {
    /// Upstream `AsyncWrite`
    inner: T,

    /// Frame-encoding state: HPACK table, write buffer, and any frame whose
    /// encoding is still in progress.
    encoder: Encoder<B>,
}
28
/// Buffering frame encoder shared by `FramedWrite`.
///
/// Holds the partially-written byte buffer plus whatever frame state must
/// survive across partial writes (a chained DATA payload or a pending
/// CONTINUATION frame).
#[derive(Debug)]
struct Encoder<B> {
    /// HPACK encoder
    hpack: hpack::Encoder,

    /// Write buffer
    ///
    /// The `Cursor` position is the read offset: bytes before it have
    /// already been written to the wire, bytes after it are still pending.
    ///
    /// TODO: Should this be a ring buffer?
    buf: Cursor<BytesMut>,

    /// Next frame to encode
    ///
    /// Either a DATA frame whose payload is chained (not copied) onto `buf`,
    /// or a CONTINUATION frame to be encoded once `buf` drains.
    next: Option<Next<B>>,

    /// Last data frame
    ///
    /// Kept so the caller can reclaim it via `take_last_data_frame` after
    /// the write completes.
    last_data_frame: Option<frame::Data<B>>,

    /// Max frame size, this is specified by the peer
    max_frame_size: FrameSize,

    /// Whether or not the wrapped `AsyncWrite` supports vectored IO.
    is_write_vectored: bool,
}
51
/// A frame whose transmission spans more than one encode/flush pass.
#[derive(Debug)]
enum Next<B> {
    /// A DATA frame whose payload is written by chaining it after the
    /// buffered frame header, avoiding a copy into the write buffer.
    Data(frame::Data<B>),
    /// A CONTINUATION frame left over from a HEADERS/PUSH_PROMISE whose
    /// header block did not fit within one max-size frame.
    Continuation(frame::Continuation),
}
57
/// Initialize the connection with this amount of write buffer.
///
/// The minimum MAX_FRAME_SIZE is 16kb, so always be able to send a HEADERS
/// frame that big.
const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024;

/// Min buffer required to attempt to write a frame
///
/// Room for a frame header plus the largest payload we would copy inline
/// (anything `CHAIN_THRESHOLD` or larger is chained, not copied).
const MIN_BUFFER_CAPACITY: usize = frame::HEADER_LEN + CHAIN_THRESHOLD;

/// Chain payloads bigger than this. The remote will never advertise a max frame
/// size less than this (well, the spec says the max frame size can't be less
/// than 16kb, so not even close).
const CHAIN_THRESHOLD: usize = 256;
71
// TODO: Make generic
impl<T, B> FramedWrite<T, B>
where
    T: AsyncWrite + Unpin,
    B: Buf,
{
    /// Wrap `inner`, probing once up front whether it supports vectored
    /// writes so `flush` can pick the cheaper write path.
    pub fn new(inner: T) -> FramedWrite<T, B> {
        let is_write_vectored = inner.is_write_vectored();
        FramedWrite {
            inner,
            encoder: Encoder {
                hpack: hpack::Encoder::default(),
                buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)),
                next: None,
                last_data_frame: None,
                max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE,
                is_write_vectored,
            },
        }
    }

    /// Returns `Ready` when `send` is able to accept a frame
    ///
    /// Calling this function may result in the current contents of the buffer
    /// to be flushed to `T`.
    pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        if !self.encoder.has_capacity() {
            // Try flushing
            ready!(self.flush(cx))?;

            // NOTE(review): a completed flush clears the buffer and any
            // pending frame, so capacity should normally be restored here;
            // this re-check is defensive. If it ever fails, we return
            // `Pending` without an obvious waker registration — presumably
            // unreachable in practice, but worth confirming.
            if !self.encoder.has_capacity() {
                return Poll::Pending;
            }
        }

        Poll::Ready(Ok(()))
    }

    /// Buffer a frame.
    ///
    /// `poll_ready` must be called first to ensure that a frame may be
    /// accepted.
    pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
        self.encoder.buffer(item)
    }

    /// Flush buffered data to the wire
    ///
    /// Drives the outer loop until the encoder is fully drained: the inner
    /// `while` writes the buffered bytes (chaining a pending DATA payload if
    /// one is queued), then `unset_frame` resets the buffer and reports
    /// whether another pass is needed (a CONTINUATION frame was re-encoded).
    pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        let span = tracing::trace_span!("FramedWrite::flush");
        let _e = span.enter();

        loop {
            while !self.encoder.is_empty() {
                match self.encoder.next {
                    Some(Next::Data(ref mut frame)) => {
                        tracing::trace!(queued_data_frame = true);
                        // Write the buffered frame header followed directly
                        // by the (un-copied) DATA payload.
                        let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut());
                        ready!(write(
                            &mut self.inner,
                            self.encoder.is_write_vectored,
                            &mut buf,
                            cx,
                        ))?
                    }
                    _ => {
                        tracing::trace!(queued_data_frame = false);
                        ready!(write(
                            &mut self.inner,
                            self.encoder.is_write_vectored,
                            &mut self.encoder.buf,
                            cx,
                        ))?
                    }
                }
            }

            // Buffer fully written; reset it and see whether a pending
            // CONTINUATION frame queued more bytes for another pass.
            match self.encoder.unset_frame() {
                ControlFlow::Continue => (),
                ControlFlow::Break => break,
            }
        }

        tracing::trace!("flushing buffer");
        // Flush the upstream
        ready!(Pin::new(&mut self.inner).poll_flush(cx))?;

        Poll::Ready(Ok(()))
    }

    /// Close the codec
    ///
    /// Flushes anything still buffered, then shuts down the underlying IO.
    pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
        ready!(self.flush(cx))?;
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }
}
167
write<T, B>( writer: &mut T, is_write_vectored: bool, buf: &mut B, cx: &mut Context<'_>, ) -> Poll<io::Result<()>> where T: AsyncWrite + Unpin, B: Buf,168 fn write<T, B>(
169 writer: &mut T,
170 is_write_vectored: bool,
171 buf: &mut B,
172 cx: &mut Context<'_>,
173 ) -> Poll<io::Result<()>>
174 where
175 T: AsyncWrite + Unpin,
176 B: Buf,
177 {
178 // TODO(eliza): when tokio-util 0.5.1 is released, this
179 // could just use `poll_write_buf`...
180 const MAX_IOVS: usize = 64;
181 let n = if is_write_vectored {
182 let mut bufs = [IoSlice::new(&[]); MAX_IOVS];
183 let cnt = buf.chunks_vectored(&mut bufs);
184 ready!(Pin::new(writer).poll_write_vectored(cx, &bufs[..cnt]))?
185 } else {
186 ready!(Pin::new(writer).poll_write(cx, buf.chunk()))?
187 };
188 buf.advance(n);
189 Ok(()).into()
190 }
191
/// Tells `flush` whether `unset_frame` queued more bytes (another write
/// pass is required) or the encoder is fully drained.
#[must_use]
enum ControlFlow {
    /// More buffered data exists (e.g. a re-encoded CONTINUATION frame).
    Continue,
    /// Nothing left to write; stop looping.
    Break,
}
197
impl<B> Encoder<B>
where
    B: Buf,
{
    /// Called once the write buffer (plus any chained DATA payload) has been
    /// fully written: reset the buffer and decide whether another flush pass
    /// is needed.
    fn unset_frame(&mut self) -> ControlFlow {
        // Clear internal buffer
        self.buf.set_position(0);
        self.buf.get_mut().clear();

        // The data frame has been written, so unset it
        match self.next.take() {
            Some(Next::Data(frame)) => {
                // Keep the frame so the caller can reclaim its buffer via
                // `take_last_data_frame`.
                self.last_data_frame = Some(frame);
                debug_assert!(self.is_empty());
                ControlFlow::Break
            }
            Some(Next::Continuation(frame)) => {
                // Buffer the continuation frame, then try to write again
                //
                // If the header block STILL doesn't fit in one frame,
                // `encode` hands back yet another CONTINUATION to queue.
                let mut buf = limited_write_buf!(self);
                if let Some(continuation) = frame.encode(&mut buf) {
                    self.next = Some(Next::Continuation(continuation));
                }
                ControlFlow::Continue
            }
            None => ControlFlow::Break,
        }
    }

    /// Encode `item` into the write buffer (or queue it in `next` when it
    /// cannot be fully encoded in one pass).
    ///
    /// The caller must have verified `has_capacity()` first; violating that
    /// is a caller bug, hence the `assert!`.
    ///
    /// Returns `Err(PayloadTooBig)` if a DATA payload exceeds the peer's
    /// advertised max frame size.
    fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
        // Ensure that we have enough capacity to accept the write.
        assert!(self.has_capacity());
        let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item);
        let _e = span.enter();

        tracing::debug!(frame = ?item, "send");

        match item {
            Frame::Data(mut v) => {
                // Ensure that the payload is not greater than the max frame.
                let len = v.payload().remaining();

                if len > self.max_frame_size() {
                    return Err(PayloadTooBig);
                }

                if len >= CHAIN_THRESHOLD {
                    // Large payload: only the frame header goes into the
                    // buffer; the payload itself is chained at write time to
                    // avoid copying it.
                    let head = v.head();

                    // Encode the frame head to the buffer
                    head.encode(len, self.buf.get_mut());

                    // Save the data frame
                    self.next = Some(Next::Data(v));
                } else {
                    // Small payload: cheaper to copy it into the buffer.
                    v.encode_chunk(self.buf.get_mut());

                    // The chunk has been fully encoded, so there is no need to
                    // keep it around
                    assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded");

                    // Save off the last frame...
                    self.last_data_frame = Some(v);
                }
            }
            Frame::Headers(v) => {
                // The header block may overflow one frame; the overflow
                // comes back as a CONTINUATION frame to encode later.
                let mut buf = limited_write_buf!(self);
                if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
                    self.next = Some(Next::Continuation(continuation));
                }
            }
            Frame::PushPromise(v) => {
                // Same continuation handling as HEADERS.
                let mut buf = limited_write_buf!(self);
                if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
                    self.next = Some(Next::Continuation(continuation));
                }
            }
            Frame::Settings(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded settings");
            }
            Frame::GoAway(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded go_away");
            }
            Frame::Ping(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded ping");
            }
            Frame::WindowUpdate(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded window_update");
            }

            Frame::Priority(_) => {
                /*
                v.encode(self.buf.get_mut());
                tracing::trace!("encoded priority; rem={:?}", self.buf.remaining());
                */
                unimplemented!();
            }
            Frame::Reset(v) => {
                v.encode(self.buf.get_mut());
                tracing::trace!(rem = self.buf.remaining(), "encoded reset");
            }
        }

        Ok(())
    }

    /// True when a new frame can be buffered: no frame is mid-encode and the
    /// buffer can still absorb a header plus an inline-copied payload.
    fn has_capacity(&self) -> bool {
        self.next.is_none() && self.buf.get_ref().remaining_mut() >= MIN_BUFFER_CAPACITY
    }

    /// True when there is nothing left to write for the current pass.
    ///
    /// With a chained DATA frame queued, emptiness is judged by the payload
    /// (the buffered header is written together with it); otherwise by the
    /// buffer cursor alone.
    fn is_empty(&self) -> bool {
        match self.next {
            Some(Next::Data(ref frame)) => !frame.payload().has_remaining(),
            _ => !self.buf.has_remaining(),
        }
    }
}
318
impl<B> Encoder<B> {
    /// Peer-advertised max frame size as `usize` for buffer-length math.
    fn max_frame_size(&self) -> usize {
        // Widening cast: `FrameSize` is a 32-bit value, which fits in
        // `usize` on all platforms this crate targets.
        self.max_frame_size as usize
    }
}
324
325 impl<T, B> FramedWrite<T, B> {
326 /// Returns the max frame size that can be sent
max_frame_size(&self) -> usize327 pub fn max_frame_size(&self) -> usize {
328 self.encoder.max_frame_size()
329 }
330
331 /// Set the peer's max frame size.
set_max_frame_size(&mut self, val: usize)332 pub fn set_max_frame_size(&mut self, val: usize) {
333 assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize);
334 self.encoder.max_frame_size = val as FrameSize;
335 }
336
337 /// Set the peer's header table size.
set_header_table_size(&mut self, val: usize)338 pub fn set_header_table_size(&mut self, val: usize) {
339 self.encoder.hpack.update_max_size(val);
340 }
341
342 /// Retrieve the last data frame that has been sent
take_last_data_frame(&mut self) -> Option<frame::Data<B>>343 pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> {
344 self.encoder.last_data_frame.take()
345 }
346
get_mut(&mut self) -> &mut T347 pub fn get_mut(&mut self) -> &mut T {
348 &mut self.inner
349 }
350 }
351
352 impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> {
poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf, ) -> Poll<io::Result<()>>353 fn poll_read(
354 mut self: Pin<&mut Self>,
355 cx: &mut Context<'_>,
356 buf: &mut ReadBuf,
357 ) -> Poll<io::Result<()>> {
358 Pin::new(&mut self.inner).poll_read(cx, buf)
359 }
360 }
361
// We never project the Pin to `B`, so `FramedWrite` is `Unpin` whenever the
// IO type `T` is — regardless of whether the payload type `B` is `Unpin`.
impl<T: Unpin, B> Unpin for FramedWrite<T, B> {}
364
#[cfg(feature = "unstable")]
mod unstable {
    use super::*;

    impl<T, B> FramedWrite<T, B> {
        /// Immutably borrow the wrapped IO handle.
        pub fn get_ref(&self) -> &T {
            let FramedWrite { inner, .. } = self;
            inner
        }
    }
}
375