1 use super::{util, StreamDependency, StreamId};
2 use crate::frame::{Error, Frame, Head, Kind};
3 use crate::hpack::{self, BytesStr};
4 
5 use http::header::{self, HeaderName, HeaderValue};
6 use http::{uri, HeaderMap, Method, Request, StatusCode, Uri};
7 
8 use bytes::BytesMut;
9 
10 use std::fmt;
11 use std::io::Cursor;
12 
/// Output buffer used when encoding frames: a `BytesMut` wrapped in
/// `bytes::buf::Limit`, which caps how many bytes may be written through it.
type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>;

// The minimum allowed MAX_FRAME_SIZE is 16 KiB, so reserve some arbitrary
// room for the frame head and other header bits.
//
// Used by `HeaderBlock::has_too_big_field` as the per-field size limit.
const MAX_HEADER_LENGTH: usize = 1024 * 16 - 100;
18 
/// A HEADERS frame.
///
/// This could carry a request, a response, or trailers. The header block is
/// held in decoded form: decoding happens in `load_hpack`, encoding in
/// `encode`.
#[derive(Eq, PartialEq)]
pub struct Headers {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The stream dependency information. Only set when the frame was loaded
    /// with the PRIORITY flag; locally constructed frames leave it `None`.
    stream_dep: Option<StreamDependency>,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: HeadersFlag,
}

/// Flags of a HEADERS frame, stored as the raw flag byte.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
39 
/// A PUSH_PROMISE frame: reserves `promised_id` for a server push and
/// carries the headers of the promised request.
#[derive(Eq, PartialEq)]
pub struct PushPromise {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The ID of the stream being reserved by this PushPromise.
    promised_id: StreamId,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: PushPromiseFlag,
}

/// Flags of a PUSH_PROMISE frame, stored as the raw flag byte.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct PushPromiseFlag(u8);
57 
/// A CONTINUATION frame that still has headers left to encode.
///
/// Produced when a header block does not fit in a single frame; it carries
/// the HPACK state needed to resume encoding in the next frame.
#[derive(Debug)]
pub struct Continuation {
    /// Stream ID of continuation frame
    stream_id: StreamId,

    /// The partially encoded header block to resume from.
    header_block: EncodingHeaderBlock,
}
65 
// TODO: These fields shouldn't be `pub`
/// HTTP/2 pseudo header fields (`:method`, `:scheme`, `:authority`,
/// `:path`, `:status`); each is `None` when absent.
#[derive(Debug, Default, Eq, PartialEq)]
pub struct Pseudo {
    // Request
    pub method: Option<Method>,
    pub scheme: Option<BytesStr>,
    pub authority: Option<BytesStr>,
    pub path: Option<BytesStr>,

    // Response
    pub status: Option<StatusCode>,
}

/// Iterator over the headers of a block being encoded: yields every set
/// pseudo header first, then the regular fields.
#[derive(Debug)]
pub struct Iter {
    /// Pseudo headers still to be yielded; set to `None` once exhausted.
    pseudo: Option<Pseudo>,

    /// Header fields
    fields: header::IntoIter<HeaderValue>,
}
87 
/// A decoded header block: pseudo headers plus regular header fields.
#[derive(Debug, PartialEq, Eq)]
struct HeaderBlock {
    /// The decoded header fields
    fields: HeaderMap,

    /// Set to true if decoding went over the max header list size; headers
    /// decoded past that point are not stored.
    is_over_size: bool,

    /// Pseudo headers, kept separate from `fields` because they must come
    /// before every regular field in the block.
    pseudo: Pseudo,
}

/// A header block in the middle of being HPACK-encoded, possibly spread
/// across several frames.
#[derive(Debug)]
struct EncodingHeaderBlock {
    /// Argument to pass to the HPACK encoder to resume encoding; `None`
    /// before the first frame has been encoded.
    hpack: Option<hpack::EncodeState>,

    /// Remaining headers to encode.
    headers: Iter,
}
109 
// Flag bits for HEADERS and PUSH_PROMISE frames (RFC 7540, sections 6.2 and
// 6.6). PUSH_PROMISE defines only END_HEADERS and PADDED.
const END_STREAM: u8 = 0x1;
const END_HEADERS: u8 = 0x4;
const PADDED: u8 = 0x8;
const PRIORITY: u8 = 0x20;
// Every flag bit defined for HEADERS frames.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
115 
116 // ===== impl Headers =====
117 
118 impl Headers {
119     /// Create a new HEADERS frame
new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self120     pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self {
121         Headers {
122             stream_id,
123             stream_dep: None,
124             header_block: HeaderBlock {
125                 fields,
126                 is_over_size: false,
127                 pseudo,
128             },
129             flags: HeadersFlag::default(),
130         }
131     }
132 
trailers(stream_id: StreamId, fields: HeaderMap) -> Self133     pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
134         let mut flags = HeadersFlag::default();
135         flags.set_end_stream();
136 
137         Headers {
138             stream_id,
139             stream_dep: None,
140             header_block: HeaderBlock {
141                 fields,
142                 is_over_size: false,
143                 pseudo: Pseudo::default(),
144             },
145             flags,
146         }
147     }
148 
149     /// Loads the header frame but doesn't actually do HPACK decoding.
150     ///
151     /// HPACK decoding is done in the `load_hpack` step.
load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error>152     pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
153         let flags = HeadersFlag(head.flag());
154         let mut pad = 0;
155 
156         tracing::trace!("loading headers; flags={:?}", flags);
157 
158         // Read the padding length
159         if flags.is_padded() {
160             if src.is_empty() {
161                 return Err(Error::MalformedMessage);
162             }
163             pad = src[0] as usize;
164 
165             // Drop the padding
166             let _ = src.split_to(1);
167         }
168 
169         // Read the stream dependency
170         let stream_dep = if flags.is_priority() {
171             if src.len() < 5 {
172                 return Err(Error::MalformedMessage);
173             }
174             let stream_dep = StreamDependency::load(&src[..5])?;
175 
176             if stream_dep.dependency_id() == head.stream_id() {
177                 return Err(Error::InvalidDependencyId);
178             }
179 
180             // Drop the next 5 bytes
181             let _ = src.split_to(5);
182 
183             Some(stream_dep)
184         } else {
185             None
186         };
187 
188         if pad > 0 {
189             if pad > src.len() {
190                 return Err(Error::TooMuchPadding);
191             }
192 
193             let len = src.len() - pad;
194             src.truncate(len);
195         }
196 
197         let headers = Headers {
198             stream_id: head.stream_id(),
199             stream_dep,
200             header_block: HeaderBlock {
201                 fields: HeaderMap::new(),
202                 is_over_size: false,
203                 pseudo: Pseudo::default(),
204             },
205             flags,
206         };
207 
208         Ok((headers, src))
209     }
210 
load_hpack( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), Error>211     pub fn load_hpack(
212         &mut self,
213         src: &mut BytesMut,
214         max_header_list_size: usize,
215         decoder: &mut hpack::Decoder,
216     ) -> Result<(), Error> {
217         self.header_block.load(src, max_header_list_size, decoder)
218     }
219 
stream_id(&self) -> StreamId220     pub fn stream_id(&self) -> StreamId {
221         self.stream_id
222     }
223 
is_end_headers(&self) -> bool224     pub fn is_end_headers(&self) -> bool {
225         self.flags.is_end_headers()
226     }
227 
set_end_headers(&mut self)228     pub fn set_end_headers(&mut self) {
229         self.flags.set_end_headers();
230     }
231 
is_end_stream(&self) -> bool232     pub fn is_end_stream(&self) -> bool {
233         self.flags.is_end_stream()
234     }
235 
set_end_stream(&mut self)236     pub fn set_end_stream(&mut self) {
237         self.flags.set_end_stream()
238     }
239 
is_over_size(&self) -> bool240     pub fn is_over_size(&self) -> bool {
241         self.header_block.is_over_size
242     }
243 
has_too_big_field(&self) -> bool244     pub(crate) fn has_too_big_field(&self) -> bool {
245         self.header_block.has_too_big_field()
246     }
247 
into_parts(self) -> (Pseudo, HeaderMap)248     pub fn into_parts(self) -> (Pseudo, HeaderMap) {
249         (self.header_block.pseudo, self.header_block.fields)
250     }
251 
252     #[cfg(feature = "unstable")]
pseudo_mut(&mut self) -> &mut Pseudo253     pub fn pseudo_mut(&mut self) -> &mut Pseudo {
254         &mut self.header_block.pseudo
255     }
256 
257     /// Whether it has status 1xx
is_informational(&self) -> bool258     pub(crate) fn is_informational(&self) -> bool {
259         self.header_block.pseudo.is_informational()
260     }
261 
fields(&self) -> &HeaderMap262     pub fn fields(&self) -> &HeaderMap {
263         &self.header_block.fields
264     }
265 
into_fields(self) -> HeaderMap266     pub fn into_fields(self) -> HeaderMap {
267         self.header_block.fields
268     }
269 
encode( self, encoder: &mut hpack::Encoder, dst: &mut EncodeBuf<'_>, ) -> Option<Continuation>270     pub fn encode(
271         self,
272         encoder: &mut hpack::Encoder,
273         dst: &mut EncodeBuf<'_>,
274     ) -> Option<Continuation> {
275         // At this point, the `is_end_headers` flag should always be set
276         debug_assert!(self.flags.is_end_headers());
277 
278         // Get the HEADERS frame head
279         let head = self.head();
280 
281         self.header_block
282             .into_encoding()
283             .encode(&head, encoder, dst, |_| {})
284     }
285 
head(&self) -> Head286     fn head(&self) -> Head {
287         Head::new(Kind::Headers, self.flags.into(), self.stream_id)
288     }
289 }
290 
291 impl<T> From<Headers> for Frame<T> {
from(src: Headers) -> Self292     fn from(src: Headers) -> Self {
293         Frame::Headers(src)
294     }
295 }
296 
297 impl fmt::Debug for Headers {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result298     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
299         let mut builder = f.debug_struct("Headers");
300         builder
301             .field("stream_id", &self.stream_id)
302             .field("flags", &self.flags);
303 
304         if let Some(ref dep) = self.stream_dep {
305             builder.field("stream_dep", dep);
306         }
307 
308         // `fields` and `pseudo` purposefully not included
309         builder.finish()
310     }
311 }
312 
313 // ===== util =====
314 
parse_u64(src: &[u8]) -> Result<u64, ()>315 pub fn parse_u64(src: &[u8]) -> Result<u64, ()> {
316     if src.len() > 19 {
317         // At danger for overflow...
318         return Err(());
319     }
320 
321     let mut ret = 0;
322 
323     for &d in src {
324         if d < b'0' || d > b'9' {
325             return Err(());
326         }
327 
328         ret *= 10;
329         ret += (d - b'0') as u64;
330     }
331 
332     Ok(ret)
333 }
334 
335 // ===== impl PushPromise =====
336 
/// Reasons a request cannot be used as a promised request; returned by
/// `PushPromise::validate_request`.
#[derive(Debug)]
pub enum PushPromiseHeaderError {
    /// A `content-length` header is present and does not parse as `0`.
    ///
    /// Carries the result of `parse_u64` on its value: `Ok(n)` for a nonzero
    /// length, `Err(())` if it is not a valid number.
    InvalidContentLength(Result<u64, ()>),
    /// The request method is not safe and cacheable.
    NotSafeAndCacheable,
}
342 
343 impl PushPromise {
new( stream_id: StreamId, promised_id: StreamId, pseudo: Pseudo, fields: HeaderMap, ) -> Self344     pub fn new(
345         stream_id: StreamId,
346         promised_id: StreamId,
347         pseudo: Pseudo,
348         fields: HeaderMap,
349     ) -> Self {
350         PushPromise {
351             flags: PushPromiseFlag::default(),
352             header_block: HeaderBlock {
353                 fields,
354                 is_over_size: false,
355                 pseudo,
356             },
357             promised_id,
358             stream_id,
359         }
360     }
361 
validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError>362     pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> {
363         use PushPromiseHeaderError::*;
364         // The spec has some requirements for promised request headers
365         // [https://httpwg.org/specs/rfc7540.html#PushRequests]
366 
367         // A promised request "that indicates the presence of a request body
368         // MUST reset the promised stream with a stream error"
369         if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) {
370             let parsed_length = parse_u64(content_length.as_bytes());
371             if parsed_length != Ok(0) {
372                 return Err(InvalidContentLength(parsed_length));
373             }
374         }
375         // "The server MUST include a method in the :method pseudo-header field
376         // that is safe and cacheable"
377         if !Self::safe_and_cacheable(req.method()) {
378             return Err(NotSafeAndCacheable);
379         }
380 
381         Ok(())
382     }
383 
safe_and_cacheable(method: &Method) -> bool384     fn safe_and_cacheable(method: &Method) -> bool {
385         // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods
386         // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods
387         return method == Method::GET || method == Method::HEAD;
388     }
389 
fields(&self) -> &HeaderMap390     pub fn fields(&self) -> &HeaderMap {
391         &self.header_block.fields
392     }
393 
394     #[cfg(feature = "unstable")]
into_fields(self) -> HeaderMap395     pub fn into_fields(self) -> HeaderMap {
396         self.header_block.fields
397     }
398 
399     /// Loads the push promise frame but doesn't actually do HPACK decoding.
400     ///
401     /// HPACK decoding is done in the `load_hpack` step.
load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error>402     pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
403         let flags = PushPromiseFlag(head.flag());
404         let mut pad = 0;
405 
406         // Read the padding length
407         if flags.is_padded() {
408             if src.is_empty() {
409                 return Err(Error::MalformedMessage);
410             }
411 
412             // TODO: Ensure payload is sized correctly
413             pad = src[0] as usize;
414 
415             // Drop the padding
416             let _ = src.split_to(1);
417         }
418 
419         if src.len() < 5 {
420             return Err(Error::MalformedMessage);
421         }
422 
423         let (promised_id, _) = StreamId::parse(&src[..4]);
424         // Drop promised_id bytes
425         let _ = src.split_to(4);
426 
427         if pad > 0 {
428             if pad > src.len() {
429                 return Err(Error::TooMuchPadding);
430             }
431 
432             let len = src.len() - pad;
433             src.truncate(len);
434         }
435 
436         let frame = PushPromise {
437             flags,
438             header_block: HeaderBlock {
439                 fields: HeaderMap::new(),
440                 is_over_size: false,
441                 pseudo: Pseudo::default(),
442             },
443             promised_id,
444             stream_id: head.stream_id(),
445         };
446         Ok((frame, src))
447     }
448 
load_hpack( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), Error>449     pub fn load_hpack(
450         &mut self,
451         src: &mut BytesMut,
452         max_header_list_size: usize,
453         decoder: &mut hpack::Decoder,
454     ) -> Result<(), Error> {
455         self.header_block.load(src, max_header_list_size, decoder)
456     }
457 
stream_id(&self) -> StreamId458     pub fn stream_id(&self) -> StreamId {
459         self.stream_id
460     }
461 
promised_id(&self) -> StreamId462     pub fn promised_id(&self) -> StreamId {
463         self.promised_id
464     }
465 
is_end_headers(&self) -> bool466     pub fn is_end_headers(&self) -> bool {
467         self.flags.is_end_headers()
468     }
469 
set_end_headers(&mut self)470     pub fn set_end_headers(&mut self) {
471         self.flags.set_end_headers();
472     }
473 
is_over_size(&self) -> bool474     pub fn is_over_size(&self) -> bool {
475         self.header_block.is_over_size
476     }
477 
encode( self, encoder: &mut hpack::Encoder, dst: &mut EncodeBuf<'_>, ) -> Option<Continuation>478     pub fn encode(
479         self,
480         encoder: &mut hpack::Encoder,
481         dst: &mut EncodeBuf<'_>,
482     ) -> Option<Continuation> {
483         use bytes::BufMut;
484 
485         // At this point, the `is_end_headers` flag should always be set
486         debug_assert!(self.flags.is_end_headers());
487 
488         let head = self.head();
489         let promised_id = self.promised_id;
490 
491         self.header_block
492             .into_encoding()
493             .encode(&head, encoder, dst, |dst| {
494                 dst.put_u32(promised_id.into());
495             })
496     }
497 
head(&self) -> Head498     fn head(&self) -> Head {
499         Head::new(Kind::PushPromise, self.flags.into(), self.stream_id)
500     }
501 
502     /// Consume `self`, returning the parts of the frame
into_parts(self) -> (Pseudo, HeaderMap)503     pub fn into_parts(self) -> (Pseudo, HeaderMap) {
504         (self.header_block.pseudo, self.header_block.fields)
505     }
506 }
507 
508 impl<T> From<PushPromise> for Frame<T> {
from(src: PushPromise) -> Self509     fn from(src: PushPromise) -> Self {
510         Frame::PushPromise(src)
511     }
512 }
513 
514 impl fmt::Debug for PushPromise {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result515     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
516         f.debug_struct("PushPromise")
517             .field("stream_id", &self.stream_id)
518             .field("promised_id", &self.promised_id)
519             .field("flags", &self.flags)
520             // `fields` and `pseudo` purposefully not included
521             .finish()
522     }
523 }
524 
525 // ===== impl Continuation =====
526 
527 impl Continuation {
head(&self) -> Head528     fn head(&self) -> Head {
529         Head::new(Kind::Continuation, END_HEADERS, self.stream_id)
530     }
531 
encode( self, encoder: &mut hpack::Encoder, dst: &mut EncodeBuf<'_>, ) -> Option<Continuation>532     pub fn encode(
533         self,
534         encoder: &mut hpack::Encoder,
535         dst: &mut EncodeBuf<'_>,
536     ) -> Option<Continuation> {
537         // Get the CONTINUATION frame head
538         let head = self.head();
539 
540         self.header_block.encode(&head, encoder, dst, |_| {})
541     }
542 }
543 
544 // ===== impl Pseudo =====
545 
546 impl Pseudo {
request(method: Method, uri: Uri) -> Self547     pub fn request(method: Method, uri: Uri) -> Self {
548         let parts = uri::Parts::from(uri);
549 
550         let mut path = parts
551             .path_and_query
552             .map(|v| BytesStr::from(v.as_str()))
553             .unwrap_or(BytesStr::from_static(""));
554 
555         match method {
556             Method::OPTIONS | Method::CONNECT => {}
557             _ if path.is_empty() => {
558                 path = BytesStr::from_static("/");
559             }
560             _ => {}
561         }
562 
563         let mut pseudo = Pseudo {
564             method: Some(method),
565             scheme: None,
566             authority: None,
567             path: Some(path).filter(|p| !p.is_empty()),
568             status: None,
569         };
570 
571         // If the URI includes a scheme component, add it to the pseudo headers
572         //
573         // TODO: Scheme must be set...
574         if let Some(scheme) = parts.scheme {
575             pseudo.set_scheme(scheme);
576         }
577 
578         // If the URI includes an authority component, add it to the pseudo
579         // headers
580         if let Some(authority) = parts.authority {
581             pseudo.set_authority(BytesStr::from(authority.as_str()));
582         }
583 
584         pseudo
585     }
586 
response(status: StatusCode) -> Self587     pub fn response(status: StatusCode) -> Self {
588         Pseudo {
589             method: None,
590             scheme: None,
591             authority: None,
592             path: None,
593             status: Some(status),
594         }
595     }
596 
set_scheme(&mut self, scheme: uri::Scheme)597     pub fn set_scheme(&mut self, scheme: uri::Scheme) {
598         let bytes_str = match scheme.as_str() {
599             "http" => BytesStr::from_static("http"),
600             "https" => BytesStr::from_static("https"),
601             s => BytesStr::from(s),
602         };
603         self.scheme = Some(bytes_str);
604     }
605 
set_authority(&mut self, authority: BytesStr)606     pub fn set_authority(&mut self, authority: BytesStr) {
607         self.authority = Some(authority);
608     }
609 
610     /// Whether it has status 1xx
is_informational(&self) -> bool611     pub(crate) fn is_informational(&self) -> bool {
612         self.status
613             .map_or(false, |status| status.is_informational())
614     }
615 }
616 
617 // ===== impl EncodingHeaderBlock =====
618 
619 impl EncodingHeaderBlock {
encode<F>( mut self, head: &Head, encoder: &mut hpack::Encoder, dst: &mut EncodeBuf<'_>, f: F, ) -> Option<Continuation> where F: FnOnce(&mut EncodeBuf<'_>),620     fn encode<F>(
621         mut self,
622         head: &Head,
623         encoder: &mut hpack::Encoder,
624         dst: &mut EncodeBuf<'_>,
625         f: F,
626     ) -> Option<Continuation>
627     where
628         F: FnOnce(&mut EncodeBuf<'_>),
629     {
630         let head_pos = dst.get_ref().len();
631 
632         // At this point, we don't know how big the h2 frame will be.
633         // So, we write the head with length 0, then write the body, and
634         // finally write the length once we know the size.
635         head.encode(0, dst);
636 
637         let payload_pos = dst.get_ref().len();
638 
639         f(dst);
640 
641         // Now, encode the header payload
642         let continuation = match encoder.encode(self.hpack, &mut self.headers, dst) {
643             hpack::Encode::Full => None,
644             hpack::Encode::Partial(state) => Some(Continuation {
645                 stream_id: head.stream_id(),
646                 header_block: EncodingHeaderBlock {
647                     hpack: Some(state),
648                     headers: self.headers,
649                 },
650             }),
651         };
652 
653         // Compute the header block length
654         let payload_len = (dst.get_ref().len() - payload_pos) as u64;
655 
656         // Write the frame length
657         let payload_len_be = payload_len.to_be_bytes();
658         assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
659         (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]);
660 
661         if continuation.is_some() {
662             // There will be continuation frames, so the `is_end_headers` flag
663             // must be unset
664             debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS);
665 
666             dst.get_mut()[head_pos + 4] -= END_HEADERS;
667         }
668 
669         continuation
670     }
671 }
672 
673 // ===== impl Iter =====
674 
675 impl Iterator for Iter {
676     type Item = hpack::Header<Option<HeaderName>>;
677 
next(&mut self) -> Option<Self::Item>678     fn next(&mut self) -> Option<Self::Item> {
679         use crate::hpack::Header::*;
680 
681         if let Some(ref mut pseudo) = self.pseudo {
682             if let Some(method) = pseudo.method.take() {
683                 return Some(Method(method));
684             }
685 
686             if let Some(scheme) = pseudo.scheme.take() {
687                 return Some(Scheme(scheme));
688             }
689 
690             if let Some(authority) = pseudo.authority.take() {
691                 return Some(Authority(authority));
692             }
693 
694             if let Some(path) = pseudo.path.take() {
695                 return Some(Path(path));
696             }
697 
698             if let Some(status) = pseudo.status.take() {
699                 return Some(Status(status));
700             }
701         }
702 
703         self.pseudo = None;
704 
705         self.fields
706             .next()
707             .map(|(name, value)| Field { name, value })
708     }
709 }
710 
711 // ===== impl HeadersFlag =====
712 
713 impl HeadersFlag {
empty() -> HeadersFlag714     pub fn empty() -> HeadersFlag {
715         HeadersFlag(0)
716     }
717 
load(bits: u8) -> HeadersFlag718     pub fn load(bits: u8) -> HeadersFlag {
719         HeadersFlag(bits & ALL)
720     }
721 
is_end_stream(&self) -> bool722     pub fn is_end_stream(&self) -> bool {
723         self.0 & END_STREAM == END_STREAM
724     }
725 
set_end_stream(&mut self)726     pub fn set_end_stream(&mut self) {
727         self.0 |= END_STREAM;
728     }
729 
is_end_headers(&self) -> bool730     pub fn is_end_headers(&self) -> bool {
731         self.0 & END_HEADERS == END_HEADERS
732     }
733 
set_end_headers(&mut self)734     pub fn set_end_headers(&mut self) {
735         self.0 |= END_HEADERS;
736     }
737 
is_padded(&self) -> bool738     pub fn is_padded(&self) -> bool {
739         self.0 & PADDED == PADDED
740     }
741 
is_priority(&self) -> bool742     pub fn is_priority(&self) -> bool {
743         self.0 & PRIORITY == PRIORITY
744     }
745 }
746 
747 impl Default for HeadersFlag {
748     /// Returns a `HeadersFlag` value with `END_HEADERS` set.
default() -> Self749     fn default() -> Self {
750         HeadersFlag(END_HEADERS)
751     }
752 }
753 
754 impl From<HeadersFlag> for u8 {
from(src: HeadersFlag) -> u8755     fn from(src: HeadersFlag) -> u8 {
756         src.0
757     }
758 }
759 
760 impl fmt::Debug for HeadersFlag {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result761     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
762         util::debug_flags(fmt, self.0)
763             .flag_if(self.is_end_headers(), "END_HEADERS")
764             .flag_if(self.is_end_stream(), "END_STREAM")
765             .flag_if(self.is_padded(), "PADDED")
766             .flag_if(self.is_priority(), "PRIORITY")
767             .finish()
768     }
769 }
770 
771 // ===== impl PushPromiseFlag =====
772 
773 impl PushPromiseFlag {
empty() -> PushPromiseFlag774     pub fn empty() -> PushPromiseFlag {
775         PushPromiseFlag(0)
776     }
777 
load(bits: u8) -> PushPromiseFlag778     pub fn load(bits: u8) -> PushPromiseFlag {
779         PushPromiseFlag(bits & ALL)
780     }
781 
is_end_headers(&self) -> bool782     pub fn is_end_headers(&self) -> bool {
783         self.0 & END_HEADERS == END_HEADERS
784     }
785 
set_end_headers(&mut self)786     pub fn set_end_headers(&mut self) {
787         self.0 |= END_HEADERS;
788     }
789 
is_padded(&self) -> bool790     pub fn is_padded(&self) -> bool {
791         self.0 & PADDED == PADDED
792     }
793 }
794 
795 impl Default for PushPromiseFlag {
796     /// Returns a `PushPromiseFlag` value with `END_HEADERS` set.
default() -> Self797     fn default() -> Self {
798         PushPromiseFlag(END_HEADERS)
799     }
800 }
801 
802 impl From<PushPromiseFlag> for u8 {
from(src: PushPromiseFlag) -> u8803     fn from(src: PushPromiseFlag) -> u8 {
804         src.0
805     }
806 }
807 
808 impl fmt::Debug for PushPromiseFlag {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result809     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
810         util::debug_flags(fmt, self.0)
811             .flag_if(self.is_end_headers(), "END_HEADERS")
812             .flag_if(self.is_padded(), "PADDED")
813             .finish()
814     }
815 }
816 
817 // ===== HeaderBlock =====
818 
impl HeaderBlock {
    /// HPACK-decodes `src` and adds the resulting headers to this block.
    ///
    /// Decoded headers are added to whatever the block already holds: both
    /// the running size and the pseudo-header ordering check start from the
    /// existing contents.
    ///
    /// Header storage stops at `max_header_list_size`:
    ///
    /// - A header is stored only while the running header list size stays
    ///   below the limit.
    /// - Once the limit is reached, later headers are dropped and
    ///   `is_over_size` is set.
    /// - Decoding still runs to completion either way, so the shared HPACK
    ///   decoder state stays consistent.
    ///
    /// # Errors
    ///
    /// - The HPACK decoder's error (converted into `Error`) if decoding fails.
    ///   This takes precedence over the malformed checks below.
    /// - `Error::MalformedMessage` in any of these cases:
    ///   - a pseudo header appears after a regular field;
    ///   - a pseudo header is repeated;
    ///   - a connection-specific header (`connection`, `transfer-encoding`,
    ///     `upgrade`, `keep-alive`, `proxy-connection`) is present;
    ///   - `te` has a value other than `trailers`.
    fn load(
        &mut self,
        src: &mut BytesMut,
        max_header_list_size: usize,
        decoder: &mut hpack::Decoder,
    ) -> Result<(), Error> {
        // `reg` becomes true once a regular field has been seen. Any pseudo
        // header after that point is malformed.
        let mut reg = !self.fields.is_empty();
        let mut malformed = false;
        let mut headers_size = self.calculate_header_list_size();

        // Stores one pseudo header, enforcing ordering, uniqueness and the
        // header list size limit. The name length is the field name plus the
        // leading ':' (e.g. ":method").
        macro_rules! set_pseudo {
            ($field:ident, $val:expr) => {{
                if reg {
                    tracing::trace!("load_hpack; header malformed -- pseudo not at head of block");
                    malformed = true;
                } else if self.pseudo.$field.is_some() {
                    tracing::trace!("load_hpack; header malformed -- repeated pseudo");
                    malformed = true;
                } else {
                    let __val = $val;
                    headers_size +=
                        decoded_header_size(stringify!($field).len() + 1, __val.as_str().len());
                    if headers_size < max_header_list_size {
                        self.pseudo.$field = Some(__val);
                    } else if !self.is_over_size {
                        tracing::trace!("load_hpack; header list size over max");
                        self.is_over_size = true;
                    }
                }
            }};
        }

        let mut cursor = Cursor::new(src);

        // If the header frame is malformed, we still have to continue decoding
        // the headers. A malformed header frame is a stream level error, but
        // the hpack state is connection level. In order to maintain correct
        // state for other streams, the hpack decoding process must complete.
        let res = decoder.decode(&mut cursor, |header| {
            use crate::hpack::Header::*;

            match header {
                Field { name, value } => {
                    // Connection level header fields are not supported and must
                    // result in a protocol error.

                    if name == header::CONNECTION
                        || name == header::TRANSFER_ENCODING
                        || name == header::UPGRADE
                        || name == "keep-alive"
                        || name == "proxy-connection"
                    {
                        tracing::trace!("load_hpack; connection level header");
                        malformed = true;
                    } else if name == header::TE && value != "trailers" {
                        tracing::trace!(
                            "load_hpack; TE header not set to trailers; val={:?}",
                            value
                        );
                        malformed = true;
                    } else {
                        reg = true;

                        // NOTE(review): the strict `<` here (and in
                        // `set_pseudo!`) treats a header list whose size
                        // exactly equals `max_header_list_size` as over size.
                        // Confirm whether `<=` was intended.
                        headers_size += decoded_header_size(name.as_str().len(), value.len());
                        if headers_size < max_header_list_size {
                            self.fields.append(name, value);
                        } else if !self.is_over_size {
                            tracing::trace!("load_hpack; header list size over max");
                            self.is_over_size = true;
                        }
                    }
                }
                Authority(v) => set_pseudo!(authority, v),
                Method(v) => set_pseudo!(method, v),
                Scheme(v) => set_pseudo!(scheme, v),
                Path(v) => set_pseudo!(path, v),
                Status(v) => set_pseudo!(status, v),
            }
        });

        if let Err(e) = res {
            tracing::trace!("hpack decoding error; err={:?}", e);
            return Err(e.into());
        }

        if malformed {
            tracing::trace!("malformed message");
            return Err(Error::MalformedMessage);
        }

        Ok(())
    }

    /// Converts the decoded block into an encoding iterator. Encoding starts
    /// with no HPACK resume state; pseudo headers are yielded before fields.
    fn into_encoding(self) -> EncodingHeaderBlock {
        EncodingHeaderBlock {
            hpack: None,
            headers: Iter {
                pseudo: Some(self.pseudo),
                fields: self.fields.into_iter(),
            },
        }
    }

    /// Calculates the size of the currently decoded header list.
    ///
    /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
    ///
    /// > The value is based on the uncompressed size of header fields,
    /// > including the length of the name and value in octets plus an
    /// > overhead of 32 octets for each header field.
    ///
    /// Pseudo header names are counted with their leading ':'.
    fn calculate_header_list_size(&self) -> usize {
        macro_rules! pseudo_size {
            ($name:ident) => {{
                self.pseudo
                    .$name
                    .as_ref()
                    .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
                    .unwrap_or(0)
            }};
        }

        pseudo_size!(method)
            + pseudo_size!(scheme)
            + pseudo_size!(status)
            + pseudo_size!(authority)
            + pseudo_size!(path)
            + self
                .fields
                .iter()
                .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len()))
                .sum::<usize>()
    }

    /// Iterate over all pseudos and headers to see if any individual pair
    /// would be too large to encode.
    ///
    /// "Too large" means the size (name + value + 32 octets of overhead)
    /// exceeds `MAX_HEADER_LENGTH`.
    pub(crate) fn has_too_big_field(&self) -> bool {
        macro_rules! pseudo_size {
            ($name:ident) => {{
                self.pseudo
                    .$name
                    .as_ref()
                    .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
                    .unwrap_or(0)
            }};
        }

        if pseudo_size!(method) > MAX_HEADER_LENGTH {
            return true;
        }

        if pseudo_size!(scheme) > MAX_HEADER_LENGTH {
            return true;
        }

        if pseudo_size!(authority) > MAX_HEADER_LENGTH {
            return true;
        }

        if pseudo_size!(path) > MAX_HEADER_LENGTH {
            return true;
        }

        // skip :status, it's never going to be too big

        for (name, value) in &self.fields {
            if decoded_header_size(name.as_str().len(), value.len()) > MAX_HEADER_LENGTH {
                return true;
            }
        }

        false
    }
}
993 
decoded_header_size(name: usize, value: usize) -> usize994 fn decoded_header_size(name: usize, value: usize) -> usize {
995     name + value + 32
996 }
997