1 use crate::codec::RecvError;
2 use crate::frame::{self, Frame, Kind, Reason};
3 use crate::frame::{
4     DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE,
5 };
6 
7 use crate::hpack;
8 
9 use futures_core::Stream;
10 
11 use bytes::BytesMut;
12 
13 use std::io;
14 
15 use std::pin::Pin;
16 use std::task::{Context, Poll};
17 use tokio::io::AsyncRead;
18 use tokio_util::codec::FramedRead as InnerFramedRead;
19 use tokio_util::codec::{LengthDelimitedCodec, LengthDelimitedCodecError};
20 
/// Default cap on the decoded header list size: 16 MB, a "sane default"
/// borrowed from golang's http2 implementation.
const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: usize = 16 * 1024 * 1024;
23 
/// Decodes HTTP/2 frames from a length-delimited byte stream.
///
/// Each chunk produced by `inner` is treated as one whole frame (frame header
/// followed by its payload). Header blocks split across CONTINUATION frames
/// are reassembled before being handed to the caller.
#[derive(Debug)]
pub struct FramedRead<T> {
    /// Length-delimited framing layer yielding one raw frame per chunk.
    inner: InnerFramedRead<T, LengthDelimitedCodec>,

    /// HPACK decoder state.
    ///
    /// Every HEADERS, PUSH_PROMISE and CONTINUATION payload on this reader is
    /// decoded through this single decoder.
    hpack: hpack::Decoder,

    /// Maximum header list size handed to HPACK decoding. It also bounds how
    /// many undecoded bytes are buffered for an over-size header block.
    max_header_list_size: usize,

    /// Header block still waiting for CONTINUATION frames, if any.
    partial: Option<Partial>,
}
35 
36 /// Partially loaded headers frame
37 #[derive(Debug)]
38 struct Partial {
39     /// Empty frame
40     frame: Continuable,
41 
42     /// Partial header payload
43     buf: BytesMut,
44 }
45 
46 #[derive(Debug)]
47 enum Continuable {
48     Headers(frame::Headers),
49     PushPromise(frame::PushPromise),
50 }
51 
52 impl<T> FramedRead<T> {
new(inner: InnerFramedRead<T, LengthDelimitedCodec>) -> FramedRead<T>53     pub fn new(inner: InnerFramedRead<T, LengthDelimitedCodec>) -> FramedRead<T> {
54         FramedRead {
55             inner,
56             hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE),
57             max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE,
58             partial: None,
59         }
60     }
61 
decode_frame(&mut self, mut bytes: BytesMut) -> Result<Option<Frame>, RecvError>62     fn decode_frame(&mut self, mut bytes: BytesMut) -> Result<Option<Frame>, RecvError> {
63         use self::RecvError::*;
64 
65         tracing::trace!("decoding frame from {}B", bytes.len());
66 
67         // Parse the head
68         let head = frame::Head::parse(&bytes);
69 
70         if self.partial.is_some() && head.kind() != Kind::Continuation {
71             proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind());
72             return Err(Connection(Reason::PROTOCOL_ERROR));
73         }
74 
75         let kind = head.kind();
76 
77         tracing::trace!("    -> kind={:?}", kind);
78 
79         macro_rules! header_block {
80             ($frame:ident, $head:ident, $bytes:ident) => ({
81                 // Drop the frame header
82                 // TODO: Change to drain: carllerche/bytes#130
83                 let _ = $bytes.split_to(frame::HEADER_LEN);
84 
85                 // Parse the header frame w/o parsing the payload
86                 let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) {
87                     Ok(res) => res,
88                     Err(frame::Error::InvalidDependencyId) => {
89                         proto_err!(stream: "invalid HEADERS dependency ID");
90                         // A stream cannot depend on itself. An endpoint MUST
91                         // treat this as a stream error (Section 5.4.2) of type
92                         // `PROTOCOL_ERROR`.
93                         return Err(Stream {
94                             id: $head.stream_id(),
95                             reason: Reason::PROTOCOL_ERROR,
96                         });
97                     },
98                     Err(e) => {
99                         proto_err!(conn: "failed to load frame; err={:?}", e);
100                         return Err(Connection(Reason::PROTOCOL_ERROR));
101                     }
102                 };
103 
104                 let is_end_headers = frame.is_end_headers();
105 
106                 // Load the HPACK encoded headers
107                 match frame.load_hpack(&mut payload, self.max_header_list_size, &mut self.hpack) {
108                     Ok(_) => {},
109                     Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {},
110                     Err(frame::Error::MalformedMessage) => {
111                         let id = $head.stream_id();
112                         proto_err!(stream: "malformed header block; stream={:?}", id);
113                         return Err(Stream {
114                             id,
115                             reason: Reason::PROTOCOL_ERROR,
116                         });
117                     },
118                     Err(e) => {
119                         proto_err!(conn: "failed HPACK decoding; err={:?}", e);
120                         return Err(Connection(Reason::PROTOCOL_ERROR));
121                     }
122                 }
123 
124                 if is_end_headers {
125                     frame.into()
126                 } else {
127                     tracing::trace!("loaded partial header block");
128                     // Defer returning the frame
129                     self.partial = Some(Partial {
130                         frame: Continuable::$frame(frame),
131                         buf: payload,
132                     });
133 
134                     return Ok(None);
135                 }
136             });
137         }
138 
139         let frame = match kind {
140             Kind::Settings => {
141                 let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]);
142 
143                 res.map_err(|e| {
144                     proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e);
145                     Connection(Reason::PROTOCOL_ERROR)
146                 })?
147                 .into()
148             }
149             Kind::Ping => {
150                 let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]);
151 
152                 res.map_err(|e| {
153                     proto_err!(conn: "failed to load PING frame; err={:?}", e);
154                     Connection(Reason::PROTOCOL_ERROR)
155                 })?
156                 .into()
157             }
158             Kind::WindowUpdate => {
159                 let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]);
160 
161                 res.map_err(|e| {
162                     proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e);
163                     Connection(Reason::PROTOCOL_ERROR)
164                 })?
165                 .into()
166             }
167             Kind::Data => {
168                 let _ = bytes.split_to(frame::HEADER_LEN);
169                 let res = frame::Data::load(head, bytes.freeze());
170 
171                 // TODO: Should this always be connection level? Probably not...
172                 res.map_err(|e| {
173                     proto_err!(conn: "failed to load DATA frame; err={:?}", e);
174                     Connection(Reason::PROTOCOL_ERROR)
175                 })?
176                 .into()
177             }
178             Kind::Headers => header_block!(Headers, head, bytes),
179             Kind::Reset => {
180                 let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]);
181                 res.map_err(|e| {
182                     proto_err!(conn: "failed to load RESET frame; err={:?}", e);
183                     Connection(Reason::PROTOCOL_ERROR)
184                 })?
185                 .into()
186             }
187             Kind::GoAway => {
188                 let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]);
189                 res.map_err(|e| {
190                     proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e);
191                     Connection(Reason::PROTOCOL_ERROR)
192                 })?
193                 .into()
194             }
195             Kind::PushPromise => header_block!(PushPromise, head, bytes),
196             Kind::Priority => {
197                 if head.stream_id() == 0 {
198                     // Invalid stream identifier
199                     proto_err!(conn: "invalid stream ID 0");
200                     return Err(Connection(Reason::PROTOCOL_ERROR));
201                 }
202 
203                 match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) {
204                     Ok(frame) => frame.into(),
205                     Err(frame::Error::InvalidDependencyId) => {
206                         // A stream cannot depend on itself. An endpoint MUST
207                         // treat this as a stream error (Section 5.4.2) of type
208                         // `PROTOCOL_ERROR`.
209                         let id = head.stream_id();
210                         proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id);
211                         return Err(Stream {
212                             id,
213                             reason: Reason::PROTOCOL_ERROR,
214                         });
215                     }
216                     Err(e) => {
217                         proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e);
218                         return Err(Connection(Reason::PROTOCOL_ERROR));
219                     }
220                 }
221             }
222             Kind::Continuation => {
223                 let is_end_headers = (head.flag() & 0x4) == 0x4;
224 
225                 let mut partial = match self.partial.take() {
226                     Some(partial) => partial,
227                     None => {
228                         proto_err!(conn: "received unexpected CONTINUATION frame");
229                         return Err(Connection(Reason::PROTOCOL_ERROR));
230                     }
231                 };
232 
233                 // The stream identifiers must match
234                 if partial.frame.stream_id() != head.stream_id() {
235                     proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID");
236                     return Err(Connection(Reason::PROTOCOL_ERROR));
237                 }
238 
239                 // Extend the buf
240                 if partial.buf.is_empty() {
241                     partial.buf = bytes.split_off(frame::HEADER_LEN);
242                 } else {
243                     if partial.frame.is_over_size() {
244                         // If there was left over bytes previously, they may be
245                         // needed to continue decoding, even though we will
246                         // be ignoring this frame. This is done to keep the HPACK
247                         // decoder state up-to-date.
248                         //
249                         // Still, we need to be careful, because if a malicious
250                         // attacker were to try to send a gigantic string, such
251                         // that it fits over multiple header blocks, we could
252                         // grow memory uncontrollably again, and that'd be a shame.
253                         //
254                         // Instead, we use a simple heuristic to determine if
255                         // we should continue to ignore decoding, or to tell
256                         // the attacker to go away.
257                         if partial.buf.len() + bytes.len() > self.max_header_list_size {
258                             proto_err!(conn: "CONTINUATION frame header block size over ignorable limit");
259                             return Err(Connection(Reason::COMPRESSION_ERROR));
260                         }
261                     }
262                     partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]);
263                 }
264 
265                 match partial.frame.load_hpack(
266                     &mut partial.buf,
267                     self.max_header_list_size,
268                     &mut self.hpack,
269                 ) {
270                     Ok(_) => {}
271                     Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_)))
272                         if !is_end_headers => {}
273                     Err(frame::Error::MalformedMessage) => {
274                         let id = head.stream_id();
275                         proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id);
276                         return Err(Stream {
277                             id,
278                             reason: Reason::PROTOCOL_ERROR,
279                         });
280                     }
281                     Err(e) => {
282                         proto_err!(conn: "failed HPACK decoding; err={:?}", e);
283                         return Err(Connection(Reason::PROTOCOL_ERROR));
284                     }
285                 }
286 
287                 if is_end_headers {
288                     partial.frame.into()
289                 } else {
290                     self.partial = Some(partial);
291                     return Ok(None);
292                 }
293             }
294             Kind::Unknown => {
295                 // Unknown frames are ignored
296                 return Ok(None);
297             }
298         };
299 
300         Ok(Some(frame))
301     }
302 
get_ref(&self) -> &T303     pub fn get_ref(&self) -> &T {
304         self.inner.get_ref()
305     }
306 
get_mut(&mut self) -> &mut T307     pub fn get_mut(&mut self) -> &mut T {
308         self.inner.get_mut()
309     }
310 
311     /// Returns the current max frame size setting
312     #[cfg(feature = "unstable")]
313     #[inline]
max_frame_size(&self) -> usize314     pub fn max_frame_size(&self) -> usize {
315         self.inner.decoder().max_frame_length()
316     }
317 
318     /// Updates the max frame size setting.
319     ///
320     /// Must be within 16,384 and 16,777,215.
321     #[inline]
set_max_frame_size(&mut self, val: usize)322     pub fn set_max_frame_size(&mut self, val: usize) {
323         assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize);
324         self.inner.decoder_mut().set_max_frame_length(val)
325     }
326 
327     /// Update the max header list size setting.
328     #[inline]
set_max_header_list_size(&mut self, val: usize)329     pub fn set_max_header_list_size(&mut self, val: usize) {
330         self.max_header_list_size = val;
331     }
332 }
333 
334 impl<T> Stream for FramedRead<T>
335 where
336     T: AsyncRead + Unpin,
337 {
338     type Item = Result<Frame, RecvError>;
339 
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>340     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
341         loop {
342             tracing::trace!("poll");
343             let bytes = match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
344                 Some(Ok(bytes)) => bytes,
345                 Some(Err(e)) => return Poll::Ready(Some(Err(map_err(e)))),
346                 None => return Poll::Ready(None),
347             };
348 
349             tracing::trace!("poll; bytes={}B", bytes.len());
350             if let Some(frame) = self.decode_frame(bytes)? {
351                 tracing::debug!("received; frame={:?}", frame);
352                 return Poll::Ready(Some(Ok(frame)));
353             }
354         }
355     }
356 }
357 
map_err(err: io::Error) -> RecvError358 fn map_err(err: io::Error) -> RecvError {
359     if let io::ErrorKind::InvalidData = err.kind() {
360         if let Some(custom) = err.get_ref() {
361             if custom.is::<LengthDelimitedCodecError>() {
362                 return RecvError::Connection(Reason::FRAME_SIZE_ERROR);
363             }
364         }
365     }
366     err.into()
367 }
368 
369 // ===== impl Continuable =====
370 
371 impl Continuable {
stream_id(&self) -> frame::StreamId372     fn stream_id(&self) -> frame::StreamId {
373         match *self {
374             Continuable::Headers(ref h) => h.stream_id(),
375             Continuable::PushPromise(ref p) => p.stream_id(),
376         }
377     }
378 
is_over_size(&self) -> bool379     fn is_over_size(&self) -> bool {
380         match *self {
381             Continuable::Headers(ref h) => h.is_over_size(),
382             Continuable::PushPromise(ref p) => p.is_over_size(),
383         }
384     }
385 
load_hpack( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), frame::Error>386     fn load_hpack(
387         &mut self,
388         src: &mut BytesMut,
389         max_header_list_size: usize,
390         decoder: &mut hpack::Decoder,
391     ) -> Result<(), frame::Error> {
392         match *self {
393             Continuable::Headers(ref mut h) => h.load_hpack(src, max_header_list_size, decoder),
394             Continuable::PushPromise(ref mut p) => p.load_hpack(src, max_header_list_size, decoder),
395         }
396     }
397 }
398 
399 impl<T> From<Continuable> for Frame<T> {
from(cont: Continuable) -> Self400     fn from(cont: Continuable) -> Self {
401         match cont {
402             Continuable::Headers(mut headers) => {
403                 headers.set_end_headers();
404                 headers.into()
405             }
406             Continuable::PushPromise(mut push) => {
407                 push.set_end_headers();
408                 push.into()
409             }
410         }
411     }
412 }
413