1 use crate::codec::RecvError;
2 use crate::frame::{self, Frame, Kind, Reason};
3 use crate::frame::{
4 DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE,
5 };
6
7 use crate::hpack;
8
9 use futures_core::Stream;
10
11 use bytes::BytesMut;
12
13 use std::io;
14
15 use std::pin::Pin;
16 use std::task::{Context, Poll};
17 use tokio::io::AsyncRead;
18 use tokio_util::codec::FramedRead as InnerFramedRead;
19 use tokio_util::codec::{LengthDelimitedCodec, LengthDelimitedCodecError};
20
/// Default limit applied to decoded header lists (`SETTINGS_MAX_HEADER_LIST_SIZE`)
/// until the caller configures one: 16 MiB, the "sane default" taken from
/// golang's http2 implementation.
const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: usize = 16 * 1024 * 1024;
23
/// Reads raw, length-delimited HTTP/2 frames from `T` and decodes them into
/// `Frame` values, reassembling header blocks that are split across
/// CONTINUATION frames.
#[derive(Debug)]
pub struct FramedRead<T> {
    // Length-delimited reader; each item is one raw frame (frame header
    // followed by its payload).
    inner: InnerFramedRead<T, LengthDelimitedCodec>,

    // hpack decoder state, shared by every header block on the connection
    hpack: hpack::Decoder,

    // Header list size limit passed to HPACK decoding of header blocks.
    max_header_list_size: usize,

    // Header block still waiting for CONTINUATION frames, if any.
    partial: Option<Partial>,
}
35
36 /// Partially loaded headers frame
37 #[derive(Debug)]
38 struct Partial {
39 /// Empty frame
40 frame: Continuable,
41
42 /// Partial header payload
43 buf: BytesMut,
44 }
45
46 #[derive(Debug)]
47 enum Continuable {
48 Headers(frame::Headers),
49 PushPromise(frame::PushPromise),
50 }
51
52 impl<T> FramedRead<T> {
new(inner: InnerFramedRead<T, LengthDelimitedCodec>) -> FramedRead<T>53 pub fn new(inner: InnerFramedRead<T, LengthDelimitedCodec>) -> FramedRead<T> {
54 FramedRead {
55 inner,
56 hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE),
57 max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE,
58 partial: None,
59 }
60 }
61
decode_frame(&mut self, mut bytes: BytesMut) -> Result<Option<Frame>, RecvError>62 fn decode_frame(&mut self, mut bytes: BytesMut) -> Result<Option<Frame>, RecvError> {
63 use self::RecvError::*;
64
65 tracing::trace!("decoding frame from {}B", bytes.len());
66
67 // Parse the head
68 let head = frame::Head::parse(&bytes);
69
70 if self.partial.is_some() && head.kind() != Kind::Continuation {
71 proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind());
72 return Err(Connection(Reason::PROTOCOL_ERROR));
73 }
74
75 let kind = head.kind();
76
77 tracing::trace!(" -> kind={:?}", kind);
78
79 macro_rules! header_block {
80 ($frame:ident, $head:ident, $bytes:ident) => ({
81 // Drop the frame header
82 // TODO: Change to drain: carllerche/bytes#130
83 let _ = $bytes.split_to(frame::HEADER_LEN);
84
85 // Parse the header frame w/o parsing the payload
86 let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) {
87 Ok(res) => res,
88 Err(frame::Error::InvalidDependencyId) => {
89 proto_err!(stream: "invalid HEADERS dependency ID");
90 // A stream cannot depend on itself. An endpoint MUST
91 // treat this as a stream error (Section 5.4.2) of type
92 // `PROTOCOL_ERROR`.
93 return Err(Stream {
94 id: $head.stream_id(),
95 reason: Reason::PROTOCOL_ERROR,
96 });
97 },
98 Err(e) => {
99 proto_err!(conn: "failed to load frame; err={:?}", e);
100 return Err(Connection(Reason::PROTOCOL_ERROR));
101 }
102 };
103
104 let is_end_headers = frame.is_end_headers();
105
106 // Load the HPACK encoded headers
107 match frame.load_hpack(&mut payload, self.max_header_list_size, &mut self.hpack) {
108 Ok(_) => {},
109 Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {},
110 Err(frame::Error::MalformedMessage) => {
111 let id = $head.stream_id();
112 proto_err!(stream: "malformed header block; stream={:?}", id);
113 return Err(Stream {
114 id,
115 reason: Reason::PROTOCOL_ERROR,
116 });
117 },
118 Err(e) => {
119 proto_err!(conn: "failed HPACK decoding; err={:?}", e);
120 return Err(Connection(Reason::PROTOCOL_ERROR));
121 }
122 }
123
124 if is_end_headers {
125 frame.into()
126 } else {
127 tracing::trace!("loaded partial header block");
128 // Defer returning the frame
129 self.partial = Some(Partial {
130 frame: Continuable::$frame(frame),
131 buf: payload,
132 });
133
134 return Ok(None);
135 }
136 });
137 }
138
139 let frame = match kind {
140 Kind::Settings => {
141 let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]);
142
143 res.map_err(|e| {
144 proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e);
145 Connection(Reason::PROTOCOL_ERROR)
146 })?
147 .into()
148 }
149 Kind::Ping => {
150 let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]);
151
152 res.map_err(|e| {
153 proto_err!(conn: "failed to load PING frame; err={:?}", e);
154 Connection(Reason::PROTOCOL_ERROR)
155 })?
156 .into()
157 }
158 Kind::WindowUpdate => {
159 let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]);
160
161 res.map_err(|e| {
162 proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e);
163 Connection(Reason::PROTOCOL_ERROR)
164 })?
165 .into()
166 }
167 Kind::Data => {
168 let _ = bytes.split_to(frame::HEADER_LEN);
169 let res = frame::Data::load(head, bytes.freeze());
170
171 // TODO: Should this always be connection level? Probably not...
172 res.map_err(|e| {
173 proto_err!(conn: "failed to load DATA frame; err={:?}", e);
174 Connection(Reason::PROTOCOL_ERROR)
175 })?
176 .into()
177 }
178 Kind::Headers => header_block!(Headers, head, bytes),
179 Kind::Reset => {
180 let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]);
181 res.map_err(|e| {
182 proto_err!(conn: "failed to load RESET frame; err={:?}", e);
183 Connection(Reason::PROTOCOL_ERROR)
184 })?
185 .into()
186 }
187 Kind::GoAway => {
188 let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]);
189 res.map_err(|e| {
190 proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e);
191 Connection(Reason::PROTOCOL_ERROR)
192 })?
193 .into()
194 }
195 Kind::PushPromise => header_block!(PushPromise, head, bytes),
196 Kind::Priority => {
197 if head.stream_id() == 0 {
198 // Invalid stream identifier
199 proto_err!(conn: "invalid stream ID 0");
200 return Err(Connection(Reason::PROTOCOL_ERROR));
201 }
202
203 match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) {
204 Ok(frame) => frame.into(),
205 Err(frame::Error::InvalidDependencyId) => {
206 // A stream cannot depend on itself. An endpoint MUST
207 // treat this as a stream error (Section 5.4.2) of type
208 // `PROTOCOL_ERROR`.
209 let id = head.stream_id();
210 proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id);
211 return Err(Stream {
212 id,
213 reason: Reason::PROTOCOL_ERROR,
214 });
215 }
216 Err(e) => {
217 proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e);
218 return Err(Connection(Reason::PROTOCOL_ERROR));
219 }
220 }
221 }
222 Kind::Continuation => {
223 let is_end_headers = (head.flag() & 0x4) == 0x4;
224
225 let mut partial = match self.partial.take() {
226 Some(partial) => partial,
227 None => {
228 proto_err!(conn: "received unexpected CONTINUATION frame");
229 return Err(Connection(Reason::PROTOCOL_ERROR));
230 }
231 };
232
233 // The stream identifiers must match
234 if partial.frame.stream_id() != head.stream_id() {
235 proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID");
236 return Err(Connection(Reason::PROTOCOL_ERROR));
237 }
238
239 // Extend the buf
240 if partial.buf.is_empty() {
241 partial.buf = bytes.split_off(frame::HEADER_LEN);
242 } else {
243 if partial.frame.is_over_size() {
244 // If there was left over bytes previously, they may be
245 // needed to continue decoding, even though we will
246 // be ignoring this frame. This is done to keep the HPACK
247 // decoder state up-to-date.
248 //
249 // Still, we need to be careful, because if a malicious
250 // attacker were to try to send a gigantic string, such
251 // that it fits over multiple header blocks, we could
252 // grow memory uncontrollably again, and that'd be a shame.
253 //
254 // Instead, we use a simple heuristic to determine if
255 // we should continue to ignore decoding, or to tell
256 // the attacker to go away.
257 if partial.buf.len() + bytes.len() > self.max_header_list_size {
258 proto_err!(conn: "CONTINUATION frame header block size over ignorable limit");
259 return Err(Connection(Reason::COMPRESSION_ERROR));
260 }
261 }
262 partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]);
263 }
264
265 match partial.frame.load_hpack(
266 &mut partial.buf,
267 self.max_header_list_size,
268 &mut self.hpack,
269 ) {
270 Ok(_) => {}
271 Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_)))
272 if !is_end_headers => {}
273 Err(frame::Error::MalformedMessage) => {
274 let id = head.stream_id();
275 proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id);
276 return Err(Stream {
277 id,
278 reason: Reason::PROTOCOL_ERROR,
279 });
280 }
281 Err(e) => {
282 proto_err!(conn: "failed HPACK decoding; err={:?}", e);
283 return Err(Connection(Reason::PROTOCOL_ERROR));
284 }
285 }
286
287 if is_end_headers {
288 partial.frame.into()
289 } else {
290 self.partial = Some(partial);
291 return Ok(None);
292 }
293 }
294 Kind::Unknown => {
295 // Unknown frames are ignored
296 return Ok(None);
297 }
298 };
299
300 Ok(Some(frame))
301 }
302
get_ref(&self) -> &T303 pub fn get_ref(&self) -> &T {
304 self.inner.get_ref()
305 }
306
get_mut(&mut self) -> &mut T307 pub fn get_mut(&mut self) -> &mut T {
308 self.inner.get_mut()
309 }
310
311 /// Returns the current max frame size setting
312 #[cfg(feature = "unstable")]
313 #[inline]
max_frame_size(&self) -> usize314 pub fn max_frame_size(&self) -> usize {
315 self.inner.decoder().max_frame_length()
316 }
317
318 /// Updates the max frame size setting.
319 ///
320 /// Must be within 16,384 and 16,777,215.
321 #[inline]
set_max_frame_size(&mut self, val: usize)322 pub fn set_max_frame_size(&mut self, val: usize) {
323 assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize);
324 self.inner.decoder_mut().set_max_frame_length(val)
325 }
326
327 /// Update the max header list size setting.
328 #[inline]
set_max_header_list_size(&mut self, val: usize)329 pub fn set_max_header_list_size(&mut self, val: usize) {
330 self.max_header_list_size = val;
331 }
332 }
333
334 impl<T> Stream for FramedRead<T>
335 where
336 T: AsyncRead + Unpin,
337 {
338 type Item = Result<Frame, RecvError>;
339
poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>340 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
341 loop {
342 tracing::trace!("poll");
343 let bytes = match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
344 Some(Ok(bytes)) => bytes,
345 Some(Err(e)) => return Poll::Ready(Some(Err(map_err(e)))),
346 None => return Poll::Ready(None),
347 };
348
349 tracing::trace!("poll; bytes={}B", bytes.len());
350 if let Some(frame) = self.decode_frame(bytes)? {
351 tracing::debug!("received; frame={:?}", frame);
352 return Poll::Ready(Some(Ok(frame)));
353 }
354 }
355 }
356 }
357
map_err(err: io::Error) -> RecvError358 fn map_err(err: io::Error) -> RecvError {
359 if let io::ErrorKind::InvalidData = err.kind() {
360 if let Some(custom) = err.get_ref() {
361 if custom.is::<LengthDelimitedCodecError>() {
362 return RecvError::Connection(Reason::FRAME_SIZE_ERROR);
363 }
364 }
365 }
366 err.into()
367 }
368
369 // ===== impl Continuable =====
370
371 impl Continuable {
stream_id(&self) -> frame::StreamId372 fn stream_id(&self) -> frame::StreamId {
373 match *self {
374 Continuable::Headers(ref h) => h.stream_id(),
375 Continuable::PushPromise(ref p) => p.stream_id(),
376 }
377 }
378
is_over_size(&self) -> bool379 fn is_over_size(&self) -> bool {
380 match *self {
381 Continuable::Headers(ref h) => h.is_over_size(),
382 Continuable::PushPromise(ref p) => p.is_over_size(),
383 }
384 }
385
load_hpack( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), frame::Error>386 fn load_hpack(
387 &mut self,
388 src: &mut BytesMut,
389 max_header_list_size: usize,
390 decoder: &mut hpack::Decoder,
391 ) -> Result<(), frame::Error> {
392 match *self {
393 Continuable::Headers(ref mut h) => h.load_hpack(src, max_header_list_size, decoder),
394 Continuable::PushPromise(ref mut p) => p.load_hpack(src, max_header_list_size, decoder),
395 }
396 }
397 }
398
399 impl<T> From<Continuable> for Frame<T> {
from(cont: Continuable) -> Self400 fn from(cont: Continuable) -> Self {
401 match cont {
402 Continuable::Headers(mut headers) => {
403 headers.set_end_headers();
404 headers.into()
405 }
406 Continuable::PushPromise(mut push) => {
407 push.set_end_headers();
408 push.into()
409 }
410 }
411 }
412 }
413