1 use super::{util, StreamDependency, StreamId};
2 use crate::ext::Protocol;
3 use crate::frame::{Error, Frame, Head, Kind};
4 use crate::hpack::{self, BytesStr};
5 
6 use http::header::{self, HeaderName, HeaderValue};
7 use http::{uri, HeaderMap, Method, Request, StatusCode, Uri};
8 
9 use bytes::{BufMut, Bytes, BytesMut};
10 
11 use std::fmt;
12 use std::io::Cursor;
13 
14 type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>;
15 /// Header frame
16 ///
17 /// This could be either a request or a response.
18 #[derive(Eq, PartialEq)]
19 pub struct Headers {
20     /// The ID of the stream with which this frame is associated.
21     stream_id: StreamId,
22 
23     /// The stream dependency information, if any.
24     stream_dep: Option<StreamDependency>,
25 
26     /// The header block fragment
27     header_block: HeaderBlock,
28 
29     /// The associated flags
30     flags: HeadersFlag,
31 }
32 
/// Flags byte of a HEADERS frame.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
35 
36 #[derive(Eq, PartialEq)]
37 pub struct PushPromise {
38     /// The ID of the stream with which this frame is associated.
39     stream_id: StreamId,
40 
41     /// The ID of the stream being reserved by this PushPromise.
42     promised_id: StreamId,
43 
44     /// The header block fragment
45     header_block: HeaderBlock,
46 
47     /// The associated flags
48     flags: PushPromiseFlag,
49 }
50 
/// Flags byte of a PUSH_PROMISE frame.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct PushPromiseFlag(u8);
53 
54 #[derive(Debug)]
55 pub struct Continuation {
56     /// Stream ID of continuation frame
57     stream_id: StreamId,
58 
59     header_block: EncodingHeaderBlock,
60 }
61 
62 // TODO: These fields shouldn't be `pub`
63 #[derive(Debug, Default, Eq, PartialEq)]
64 pub struct Pseudo {
65     // Request
66     pub method: Option<Method>,
67     pub scheme: Option<BytesStr>,
68     pub authority: Option<BytesStr>,
69     pub path: Option<BytesStr>,
70     pub protocol: Option<Protocol>,
71 
72     // Response
73     pub status: Option<StatusCode>,
74 }
75 
76 #[derive(Debug)]
77 pub struct Iter {
78     /// Pseudo headers
79     pseudo: Option<Pseudo>,
80 
81     /// Header fields
82     fields: header::IntoIter<HeaderValue>,
83 }
84 
85 #[derive(Debug, PartialEq, Eq)]
86 struct HeaderBlock {
87     /// The decoded header fields
88     fields: HeaderMap,
89 
90     /// Set to true if decoding went over the max header list size.
91     is_over_size: bool,
92 
93     /// Pseudo headers, these are broken out as they must be sent as part of the
94     /// headers frame.
95     pseudo: Pseudo,
96 }
97 
98 #[derive(Debug)]
99 struct EncodingHeaderBlock {
100     hpack: Bytes,
101 }
102 
// Flag bits used by HEADERS and PUSH_PROMISE frames.
const END_STREAM: u8 = 0x01;
const END_HEADERS: u8 = 0x04;
const PADDED: u8 = 0x08;
const PRIORITY: u8 = 0x20;
// Every flag bit defined for a HEADERS frame.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
108 
109 // ===== impl Headers =====
110 
111 impl Headers {
112     /// Create a new HEADERS frame
new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self113     pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self {
114         Headers {
115             stream_id,
116             stream_dep: None,
117             header_block: HeaderBlock {
118                 fields,
119                 is_over_size: false,
120                 pseudo,
121             },
122             flags: HeadersFlag::default(),
123         }
124     }
125 
trailers(stream_id: StreamId, fields: HeaderMap) -> Self126     pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
127         let mut flags = HeadersFlag::default();
128         flags.set_end_stream();
129 
130         Headers {
131             stream_id,
132             stream_dep: None,
133             header_block: HeaderBlock {
134                 fields,
135                 is_over_size: false,
136                 pseudo: Pseudo::default(),
137             },
138             flags,
139         }
140     }
141 
142     /// Loads the header frame but doesn't actually do HPACK decoding.
143     ///
144     /// HPACK decoding is done in the `load_hpack` step.
load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error>145     pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
146         let flags = HeadersFlag(head.flag());
147         let mut pad = 0;
148 
149         tracing::trace!("loading headers; flags={:?}", flags);
150 
151         if head.stream_id().is_zero() {
152             return Err(Error::InvalidStreamId);
153         }
154 
155         // Read the padding length
156         if flags.is_padded() {
157             if src.is_empty() {
158                 return Err(Error::MalformedMessage);
159             }
160             pad = src[0] as usize;
161 
162             // Drop the padding
163             let _ = src.split_to(1);
164         }
165 
166         // Read the stream dependency
167         let stream_dep = if flags.is_priority() {
168             if src.len() < 5 {
169                 return Err(Error::MalformedMessage);
170             }
171             let stream_dep = StreamDependency::load(&src[..5])?;
172 
173             if stream_dep.dependency_id() == head.stream_id() {
174                 return Err(Error::InvalidDependencyId);
175             }
176 
177             // Drop the next 5 bytes
178             let _ = src.split_to(5);
179 
180             Some(stream_dep)
181         } else {
182             None
183         };
184 
185         if pad > 0 {
186             if pad > src.len() {
187                 return Err(Error::TooMuchPadding);
188             }
189 
190             let len = src.len() - pad;
191             src.truncate(len);
192         }
193 
194         let headers = Headers {
195             stream_id: head.stream_id(),
196             stream_dep,
197             header_block: HeaderBlock {
198                 fields: HeaderMap::new(),
199                 is_over_size: false,
200                 pseudo: Pseudo::default(),
201             },
202             flags,
203         };
204 
205         Ok((headers, src))
206     }
207 
load_hpack( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), Error>208     pub fn load_hpack(
209         &mut self,
210         src: &mut BytesMut,
211         max_header_list_size: usize,
212         decoder: &mut hpack::Decoder,
213     ) -> Result<(), Error> {
214         self.header_block.load(src, max_header_list_size, decoder)
215     }
216 
stream_id(&self) -> StreamId217     pub fn stream_id(&self) -> StreamId {
218         self.stream_id
219     }
220 
is_end_headers(&self) -> bool221     pub fn is_end_headers(&self) -> bool {
222         self.flags.is_end_headers()
223     }
224 
set_end_headers(&mut self)225     pub fn set_end_headers(&mut self) {
226         self.flags.set_end_headers();
227     }
228 
is_end_stream(&self) -> bool229     pub fn is_end_stream(&self) -> bool {
230         self.flags.is_end_stream()
231     }
232 
set_end_stream(&mut self)233     pub fn set_end_stream(&mut self) {
234         self.flags.set_end_stream()
235     }
236 
is_over_size(&self) -> bool237     pub fn is_over_size(&self) -> bool {
238         self.header_block.is_over_size
239     }
240 
into_parts(self) -> (Pseudo, HeaderMap)241     pub fn into_parts(self) -> (Pseudo, HeaderMap) {
242         (self.header_block.pseudo, self.header_block.fields)
243     }
244 
245     #[cfg(feature = "unstable")]
pseudo_mut(&mut self) -> &mut Pseudo246     pub fn pseudo_mut(&mut self) -> &mut Pseudo {
247         &mut self.header_block.pseudo
248     }
249 
250     /// Whether it has status 1xx
is_informational(&self) -> bool251     pub(crate) fn is_informational(&self) -> bool {
252         self.header_block.pseudo.is_informational()
253     }
254 
fields(&self) -> &HeaderMap255     pub fn fields(&self) -> &HeaderMap {
256         &self.header_block.fields
257     }
258 
into_fields(self) -> HeaderMap259     pub fn into_fields(self) -> HeaderMap {
260         self.header_block.fields
261     }
262 
encode( self, encoder: &mut hpack::Encoder, dst: &mut EncodeBuf<'_>, ) -> Option<Continuation>263     pub fn encode(
264         self,
265         encoder: &mut hpack::Encoder,
266         dst: &mut EncodeBuf<'_>,
267     ) -> Option<Continuation> {
268         // At this point, the `is_end_headers` flag should always be set
269         debug_assert!(self.flags.is_end_headers());
270 
271         // Get the HEADERS frame head
272         let head = self.head();
273 
274         self.header_block
275             .into_encoding(encoder)
276             .encode(&head, dst, |_| {})
277     }
278 
head(&self) -> Head279     fn head(&self) -> Head {
280         Head::new(Kind::Headers, self.flags.into(), self.stream_id)
281     }
282 }
283 
284 impl<T> From<Headers> for Frame<T> {
from(src: Headers) -> Self285     fn from(src: Headers) -> Self {
286         Frame::Headers(src)
287     }
288 }
289 
290 impl fmt::Debug for Headers {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result291     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
292         let mut builder = f.debug_struct("Headers");
293         builder
294             .field("stream_id", &self.stream_id)
295             .field("flags", &self.flags);
296 
297         if let Some(ref protocol) = self.header_block.pseudo.protocol {
298             builder.field("protocol", protocol);
299         }
300 
301         if let Some(ref dep) = self.stream_dep {
302             builder.field("stream_dep", dep);
303         }
304 
305         // `fields` and `pseudo` purposefully not included
306         builder.finish()
307     }
308 }
309 
310 // ===== util =====
311 
/// Parses an unsigned decimal integer made only of ASCII digits, such as a
/// `content-length` value.
///
/// Returns `Err(())` in three cases:
/// - the input is empty (an empty value is not a number, so it must not be
///   read as `0`);
/// - the input contains any non-digit byte, including signs and whitespace;
/// - the value does not fit in a `u64`.
///
/// Overflow is detected with checked arithmetic rather than a length cap, so
/// every value up to `u64::MAX` is accepted.
pub fn parse_u64(src: &[u8]) -> Result<u64, ()> {
    if src.is_empty() {
        return Err(());
    }

    src.iter().try_fold(0u64, |acc, &d| {
        if !d.is_ascii_digit() {
            return Err(());
        }
        acc.checked_mul(10)
            .and_then(|acc| acc.checked_add(u64::from(d - b'0')))
            .ok_or(())
    })
}
331 
332 // ===== impl PushPromise =====
333 
/// Reasons a request cannot be used as a promised (pushed) request.
#[derive(Debug)]
pub enum PushPromiseHeaderError {
    /// A `content-length` header was present but did not parse as `0`.
    /// Carries the parse result.
    InvalidContentLength(Result<u64, ()>),
    /// The method is not both safe and cacheable.
    NotSafeAndCacheable,
}
339 
340 impl PushPromise {
new( stream_id: StreamId, promised_id: StreamId, pseudo: Pseudo, fields: HeaderMap, ) -> Self341     pub fn new(
342         stream_id: StreamId,
343         promised_id: StreamId,
344         pseudo: Pseudo,
345         fields: HeaderMap,
346     ) -> Self {
347         PushPromise {
348             flags: PushPromiseFlag::default(),
349             header_block: HeaderBlock {
350                 fields,
351                 is_over_size: false,
352                 pseudo,
353             },
354             promised_id,
355             stream_id,
356         }
357     }
358 
validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError>359     pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> {
360         use PushPromiseHeaderError::*;
361         // The spec has some requirements for promised request headers
362         // [https://httpwg.org/specs/rfc7540.html#PushRequests]
363 
364         // A promised request "that indicates the presence of a request body
365         // MUST reset the promised stream with a stream error"
366         if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) {
367             let parsed_length = parse_u64(content_length.as_bytes());
368             if parsed_length != Ok(0) {
369                 return Err(InvalidContentLength(parsed_length));
370             }
371         }
372         // "The server MUST include a method in the :method pseudo-header field
373         // that is safe and cacheable"
374         if !Self::safe_and_cacheable(req.method()) {
375             return Err(NotSafeAndCacheable);
376         }
377 
378         Ok(())
379     }
380 
safe_and_cacheable(method: &Method) -> bool381     fn safe_and_cacheable(method: &Method) -> bool {
382         // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods
383         // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods
384         return method == Method::GET || method == Method::HEAD;
385     }
386 
fields(&self) -> &HeaderMap387     pub fn fields(&self) -> &HeaderMap {
388         &self.header_block.fields
389     }
390 
391     #[cfg(feature = "unstable")]
into_fields(self) -> HeaderMap392     pub fn into_fields(self) -> HeaderMap {
393         self.header_block.fields
394     }
395 
396     /// Loads the push promise frame but doesn't actually do HPACK decoding.
397     ///
398     /// HPACK decoding is done in the `load_hpack` step.
load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error>399     pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
400         let flags = PushPromiseFlag(head.flag());
401         let mut pad = 0;
402 
403         // Read the padding length
404         if flags.is_padded() {
405             if src.is_empty() {
406                 return Err(Error::MalformedMessage);
407             }
408 
409             // TODO: Ensure payload is sized correctly
410             pad = src[0] as usize;
411 
412             // Drop the padding
413             let _ = src.split_to(1);
414         }
415 
416         if src.len() < 5 {
417             return Err(Error::MalformedMessage);
418         }
419 
420         let (promised_id, _) = StreamId::parse(&src[..4]);
421         // Drop promised_id bytes
422         let _ = src.split_to(4);
423 
424         if pad > 0 {
425             if pad > src.len() {
426                 return Err(Error::TooMuchPadding);
427             }
428 
429             let len = src.len() - pad;
430             src.truncate(len);
431         }
432 
433         let frame = PushPromise {
434             flags,
435             header_block: HeaderBlock {
436                 fields: HeaderMap::new(),
437                 is_over_size: false,
438                 pseudo: Pseudo::default(),
439             },
440             promised_id,
441             stream_id: head.stream_id(),
442         };
443         Ok((frame, src))
444     }
445 
load_hpack( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), Error>446     pub fn load_hpack(
447         &mut self,
448         src: &mut BytesMut,
449         max_header_list_size: usize,
450         decoder: &mut hpack::Decoder,
451     ) -> Result<(), Error> {
452         self.header_block.load(src, max_header_list_size, decoder)
453     }
454 
stream_id(&self) -> StreamId455     pub fn stream_id(&self) -> StreamId {
456         self.stream_id
457     }
458 
promised_id(&self) -> StreamId459     pub fn promised_id(&self) -> StreamId {
460         self.promised_id
461     }
462 
is_end_headers(&self) -> bool463     pub fn is_end_headers(&self) -> bool {
464         self.flags.is_end_headers()
465     }
466 
set_end_headers(&mut self)467     pub fn set_end_headers(&mut self) {
468         self.flags.set_end_headers();
469     }
470 
is_over_size(&self) -> bool471     pub fn is_over_size(&self) -> bool {
472         self.header_block.is_over_size
473     }
474 
encode( self, encoder: &mut hpack::Encoder, dst: &mut EncodeBuf<'_>, ) -> Option<Continuation>475     pub fn encode(
476         self,
477         encoder: &mut hpack::Encoder,
478         dst: &mut EncodeBuf<'_>,
479     ) -> Option<Continuation> {
480         // At this point, the `is_end_headers` flag should always be set
481         debug_assert!(self.flags.is_end_headers());
482 
483         let head = self.head();
484         let promised_id = self.promised_id;
485 
486         self.header_block
487             .into_encoding(encoder)
488             .encode(&head, dst, |dst| {
489                 dst.put_u32(promised_id.into());
490             })
491     }
492 
head(&self) -> Head493     fn head(&self) -> Head {
494         Head::new(Kind::PushPromise, self.flags.into(), self.stream_id)
495     }
496 
497     /// Consume `self`, returning the parts of the frame
into_parts(self) -> (Pseudo, HeaderMap)498     pub fn into_parts(self) -> (Pseudo, HeaderMap) {
499         (self.header_block.pseudo, self.header_block.fields)
500     }
501 }
502 
503 impl<T> From<PushPromise> for Frame<T> {
from(src: PushPromise) -> Self504     fn from(src: PushPromise) -> Self {
505         Frame::PushPromise(src)
506     }
507 }
508 
509 impl fmt::Debug for PushPromise {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result510     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
511         f.debug_struct("PushPromise")
512             .field("stream_id", &self.stream_id)
513             .field("promised_id", &self.promised_id)
514             .field("flags", &self.flags)
515             // `fields` and `pseudo` purposefully not included
516             .finish()
517     }
518 }
519 
520 // ===== impl Continuation =====
521 
522 impl Continuation {
head(&self) -> Head523     fn head(&self) -> Head {
524         Head::new(Kind::Continuation, END_HEADERS, self.stream_id)
525     }
526 
encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation>527     pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> {
528         // Get the CONTINUATION frame head
529         let head = self.head();
530 
531         self.header_block.encode(&head, dst, |_| {})
532     }
533 }
534 
535 // ===== impl Pseudo =====
536 
537 impl Pseudo {
request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self538     pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
539         let parts = uri::Parts::from(uri);
540 
541         let mut path = parts
542             .path_and_query
543             .map(|v| BytesStr::from(v.as_str()))
544             .unwrap_or(BytesStr::from_static(""));
545 
546         match method {
547             Method::OPTIONS | Method::CONNECT => {}
548             _ if path.is_empty() => {
549                 path = BytesStr::from_static("/");
550             }
551             _ => {}
552         }
553 
554         let mut pseudo = Pseudo {
555             method: Some(method),
556             scheme: None,
557             authority: None,
558             path: Some(path).filter(|p| !p.is_empty()),
559             protocol,
560             status: None,
561         };
562 
563         // If the URI includes a scheme component, add it to the pseudo headers
564         //
565         // TODO: Scheme must be set...
566         if let Some(scheme) = parts.scheme {
567             pseudo.set_scheme(scheme);
568         }
569 
570         // If the URI includes an authority component, add it to the pseudo
571         // headers
572         if let Some(authority) = parts.authority {
573             pseudo.set_authority(BytesStr::from(authority.as_str()));
574         }
575 
576         pseudo
577     }
578 
response(status: StatusCode) -> Self579     pub fn response(status: StatusCode) -> Self {
580         Pseudo {
581             method: None,
582             scheme: None,
583             authority: None,
584             path: None,
585             protocol: None,
586             status: Some(status),
587         }
588     }
589 
590     #[cfg(feature = "unstable")]
set_status(&mut self, value: StatusCode)591     pub fn set_status(&mut self, value: StatusCode) {
592         self.status = Some(value);
593     }
594 
set_scheme(&mut self, scheme: uri::Scheme)595     pub fn set_scheme(&mut self, scheme: uri::Scheme) {
596         let bytes_str = match scheme.as_str() {
597             "http" => BytesStr::from_static("http"),
598             "https" => BytesStr::from_static("https"),
599             s => BytesStr::from(s),
600         };
601         self.scheme = Some(bytes_str);
602     }
603 
604     #[cfg(feature = "unstable")]
set_protocol(&mut self, protocol: Protocol)605     pub fn set_protocol(&mut self, protocol: Protocol) {
606         self.protocol = Some(protocol);
607     }
608 
set_authority(&mut self, authority: BytesStr)609     pub fn set_authority(&mut self, authority: BytesStr) {
610         self.authority = Some(authority);
611     }
612 
613     /// Whether it has status 1xx
is_informational(&self) -> bool614     pub(crate) fn is_informational(&self) -> bool {
615         self.status
616             .map_or(false, |status| status.is_informational())
617     }
618 }
619 
620 // ===== impl EncodingHeaderBlock =====
621 
622 impl EncodingHeaderBlock {
encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation> where F: FnOnce(&mut EncodeBuf<'_>),623     fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation>
624     where
625         F: FnOnce(&mut EncodeBuf<'_>),
626     {
627         let head_pos = dst.get_ref().len();
628 
629         // At this point, we don't know how big the h2 frame will be.
630         // So, we write the head with length 0, then write the body, and
631         // finally write the length once we know the size.
632         head.encode(0, dst);
633 
634         let payload_pos = dst.get_ref().len();
635 
636         f(dst);
637 
638         // Now, encode the header payload
639         let continuation = if self.hpack.len() > dst.remaining_mut() {
640             dst.put_slice(&self.hpack.split_to(dst.remaining_mut()));
641 
642             Some(Continuation {
643                 stream_id: head.stream_id(),
644                 header_block: self,
645             })
646         } else {
647             dst.put_slice(&self.hpack);
648 
649             None
650         };
651 
652         // Compute the header block length
653         let payload_len = (dst.get_ref().len() - payload_pos) as u64;
654 
655         // Write the frame length
656         let payload_len_be = payload_len.to_be_bytes();
657         assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
658         (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]);
659 
660         if continuation.is_some() {
661             // There will be continuation frames, so the `is_end_headers` flag
662             // must be unset
663             debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS);
664 
665             dst.get_mut()[head_pos + 4] -= END_HEADERS;
666         }
667 
668         continuation
669     }
670 }
671 
672 // ===== impl Iter =====
673 
674 impl Iterator for Iter {
675     type Item = hpack::Header<Option<HeaderName>>;
676 
next(&mut self) -> Option<Self::Item>677     fn next(&mut self) -> Option<Self::Item> {
678         use crate::hpack::Header::*;
679 
680         if let Some(ref mut pseudo) = self.pseudo {
681             if let Some(method) = pseudo.method.take() {
682                 return Some(Method(method));
683             }
684 
685             if let Some(scheme) = pseudo.scheme.take() {
686                 return Some(Scheme(scheme));
687             }
688 
689             if let Some(authority) = pseudo.authority.take() {
690                 return Some(Authority(authority));
691             }
692 
693             if let Some(path) = pseudo.path.take() {
694                 return Some(Path(path));
695             }
696 
697             if let Some(protocol) = pseudo.protocol.take() {
698                 return Some(Protocol(protocol));
699             }
700 
701             if let Some(status) = pseudo.status.take() {
702                 return Some(Status(status));
703             }
704         }
705 
706         self.pseudo = None;
707 
708         self.fields
709             .next()
710             .map(|(name, value)| Field { name, value })
711     }
712 }
713 
714 // ===== impl HeadersFlag =====
715 
716 impl HeadersFlag {
empty() -> HeadersFlag717     pub fn empty() -> HeadersFlag {
718         HeadersFlag(0)
719     }
720 
load(bits: u8) -> HeadersFlag721     pub fn load(bits: u8) -> HeadersFlag {
722         HeadersFlag(bits & ALL)
723     }
724 
is_end_stream(&self) -> bool725     pub fn is_end_stream(&self) -> bool {
726         self.0 & END_STREAM == END_STREAM
727     }
728 
set_end_stream(&mut self)729     pub fn set_end_stream(&mut self) {
730         self.0 |= END_STREAM;
731     }
732 
is_end_headers(&self) -> bool733     pub fn is_end_headers(&self) -> bool {
734         self.0 & END_HEADERS == END_HEADERS
735     }
736 
set_end_headers(&mut self)737     pub fn set_end_headers(&mut self) {
738         self.0 |= END_HEADERS;
739     }
740 
is_padded(&self) -> bool741     pub fn is_padded(&self) -> bool {
742         self.0 & PADDED == PADDED
743     }
744 
is_priority(&self) -> bool745     pub fn is_priority(&self) -> bool {
746         self.0 & PRIORITY == PRIORITY
747     }
748 }
749 
750 impl Default for HeadersFlag {
751     /// Returns a `HeadersFlag` value with `END_HEADERS` set.
default() -> Self752     fn default() -> Self {
753         HeadersFlag(END_HEADERS)
754     }
755 }
756 
757 impl From<HeadersFlag> for u8 {
from(src: HeadersFlag) -> u8758     fn from(src: HeadersFlag) -> u8 {
759         src.0
760     }
761 }
762 
763 impl fmt::Debug for HeadersFlag {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result764     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
765         util::debug_flags(fmt, self.0)
766             .flag_if(self.is_end_headers(), "END_HEADERS")
767             .flag_if(self.is_end_stream(), "END_STREAM")
768             .flag_if(self.is_padded(), "PADDED")
769             .flag_if(self.is_priority(), "PRIORITY")
770             .finish()
771     }
772 }
773 
774 // ===== impl PushPromiseFlag =====
775 
776 impl PushPromiseFlag {
empty() -> PushPromiseFlag777     pub fn empty() -> PushPromiseFlag {
778         PushPromiseFlag(0)
779     }
780 
load(bits: u8) -> PushPromiseFlag781     pub fn load(bits: u8) -> PushPromiseFlag {
782         PushPromiseFlag(bits & ALL)
783     }
784 
is_end_headers(&self) -> bool785     pub fn is_end_headers(&self) -> bool {
786         self.0 & END_HEADERS == END_HEADERS
787     }
788 
set_end_headers(&mut self)789     pub fn set_end_headers(&mut self) {
790         self.0 |= END_HEADERS;
791     }
792 
is_padded(&self) -> bool793     pub fn is_padded(&self) -> bool {
794         self.0 & PADDED == PADDED
795     }
796 }
797 
798 impl Default for PushPromiseFlag {
799     /// Returns a `PushPromiseFlag` value with `END_HEADERS` set.
default() -> Self800     fn default() -> Self {
801         PushPromiseFlag(END_HEADERS)
802     }
803 }
804 
805 impl From<PushPromiseFlag> for u8 {
from(src: PushPromiseFlag) -> u8806     fn from(src: PushPromiseFlag) -> u8 {
807         src.0
808     }
809 }
810 
811 impl fmt::Debug for PushPromiseFlag {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result812     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
813         util::debug_flags(fmt, self.0)
814             .flag_if(self.is_end_headers(), "END_HEADERS")
815             .flag_if(self.is_padded(), "PADDED")
816             .finish()
817     }
818 }
819 
820 // ===== HeaderBlock =====
821 
822 impl HeaderBlock {
load( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), Error>823     fn load(
824         &mut self,
825         src: &mut BytesMut,
826         max_header_list_size: usize,
827         decoder: &mut hpack::Decoder,
828     ) -> Result<(), Error> {
829         let mut reg = !self.fields.is_empty();
830         let mut malformed = false;
831         let mut headers_size = self.calculate_header_list_size();
832 
833         macro_rules! set_pseudo {
834             ($field:ident, $val:expr) => {{
835                 if reg {
836                     tracing::trace!("load_hpack; header malformed -- pseudo not at head of block");
837                     malformed = true;
838                 } else if self.pseudo.$field.is_some() {
839                     tracing::trace!("load_hpack; header malformed -- repeated pseudo");
840                     malformed = true;
841                 } else {
842                     let __val = $val;
843                     headers_size +=
844                         decoded_header_size(stringify!($field).len() + 1, __val.as_str().len());
845                     if headers_size < max_header_list_size {
846                         self.pseudo.$field = Some(__val);
847                     } else if !self.is_over_size {
848                         tracing::trace!("load_hpack; header list size over max");
849                         self.is_over_size = true;
850                     }
851                 }
852             }};
853         }
854 
855         let mut cursor = Cursor::new(src);
856 
857         // If the header frame is malformed, we still have to continue decoding
858         // the headers. A malformed header frame is a stream level error, but
859         // the hpack state is connection level. In order to maintain correct
860         // state for other streams, the hpack decoding process must complete.
861         let res = decoder.decode(&mut cursor, |header| {
862             use crate::hpack::Header::*;
863 
864             match header {
865                 Field { name, value } => {
866                     // Connection level header fields are not supported and must
867                     // result in a protocol error.
868 
869                     if name == header::CONNECTION
870                         || name == header::TRANSFER_ENCODING
871                         || name == header::UPGRADE
872                         || name == "keep-alive"
873                         || name == "proxy-connection"
874                     {
875                         tracing::trace!("load_hpack; connection level header");
876                         malformed = true;
877                     } else if name == header::TE && value != "trailers" {
878                         tracing::trace!(
879                             "load_hpack; TE header not set to trailers; val={:?}",
880                             value
881                         );
882                         malformed = true;
883                     } else {
884                         reg = true;
885 
886                         headers_size += decoded_header_size(name.as_str().len(), value.len());
887                         if headers_size < max_header_list_size {
888                             self.fields.append(name, value);
889                         } else if !self.is_over_size {
890                             tracing::trace!("load_hpack; header list size over max");
891                             self.is_over_size = true;
892                         }
893                     }
894                 }
895                 Authority(v) => set_pseudo!(authority, v),
896                 Method(v) => set_pseudo!(method, v),
897                 Scheme(v) => set_pseudo!(scheme, v),
898                 Path(v) => set_pseudo!(path, v),
899                 Protocol(v) => set_pseudo!(protocol, v),
900                 Status(v) => set_pseudo!(status, v),
901             }
902         });
903 
904         if let Err(e) = res {
905             tracing::trace!("hpack decoding error; err={:?}", e);
906             return Err(e.into());
907         }
908 
909         if malformed {
910             tracing::trace!("malformed message");
911             return Err(Error::MalformedMessage);
912         }
913 
914         Ok(())
915     }
916 
into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock917     fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock {
918         let mut hpack = BytesMut::new();
919         let headers = Iter {
920             pseudo: Some(self.pseudo),
921             fields: self.fields.into_iter(),
922         };
923 
924         encoder.encode(headers, &mut hpack);
925 
926         EncodingHeaderBlock {
927             hpack: hpack.freeze(),
928         }
929     }
930 
931     /// Calculates the size of the currently decoded header list.
932     ///
933     /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
934     ///
935     /// > The value is based on the uncompressed size of header fields,
936     /// > including the length of the name and value in octets plus an
937     /// > overhead of 32 octets for each header field.
calculate_header_list_size(&self) -> usize938     fn calculate_header_list_size(&self) -> usize {
939         macro_rules! pseudo_size {
940             ($name:ident) => {{
941                 self.pseudo
942                     .$name
943                     .as_ref()
944                     .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
945                     .unwrap_or(0)
946             }};
947         }
948 
949         pseudo_size!(method)
950             + pseudo_size!(scheme)
951             + pseudo_size!(status)
952             + pseudo_size!(authority)
953             + pseudo_size!(path)
954             + self
955                 .fields
956                 .iter()
957                 .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len()))
958                 .sum::<usize>()
959     }
960 }
961 
decoded_header_size(name: usize, value: usize) -> usize962 fn decoded_header_size(name: usize, value: usize) -> usize {
963     name + value + 32
964 }
965 
#[cfg(test)]
mod test {
    use std::iter::FromIterator;

    use http::HeaderValue;

    use super::*;
    use crate::frame;
    use crate::hpack::{huffman, Encoder};

    /// Encodes a header block too large for one frame and checks the split.
    ///
    /// The block is split across a HEADERS frame and a CONTINUATION frame.
    /// The test checks that encoding resumes mid-value, and that fields
    /// encoded after the resume point refer to the already-indexed name
    /// instead of repeating it literally.
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        // Three values that all share the header name `hello`.
        let headers = Headers::new(
            StreamId::ZERO,
            Default::default(),
            HeaderMap::from_iter(vec![
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("world"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("zomg"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("sup"),
                ),
            ]),
        );

        // Limit the buffer to one frame header plus 8 payload bytes so the
        // header block cannot fit and a continuation is returned.
        let continuation = headers
            .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8))
            .unwrap();

        // First frame: 9-byte frame header + 8-byte payload.
        assert_eq!(17, dst.len());
        // Frame header: length 8, type 1 (HEADERS), no flags, stream 0.
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        // 0x40: literal field with incremental indexing and a literal name
        // (RFC 7541 §6.2.1); the name is Huffman-coded in 4 bytes.
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        // Value length prefix: Huffman-coded, 4 bytes long.
        assert_eq!(0x80 | 4, dst[15]);

        // Only the first byte of the `world` value fits in this frame; keep it
        // so it can be joined with the remainder from the next frame.
        let mut world = dst[16..17].to_owned();

        dst.clear();

        // Everything left fits within 16 payload bytes, so no further
        // continuation is produced.
        assert!(continuation
            .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16))
            .is_none());

        // The remaining 3 bytes of the `world` value lead the new payload.
        world.extend_from_slice(&dst[9..12]);
        assert_eq!("world", huff_decode(&world));

        // Second frame: 9-byte frame header + 15-byte payload.
        assert_eq!(24, dst.len());
        // Frame header: length 15, type 9 (CONTINUATION), flags 4
        // (END_HEADERS), stream 0.
        assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]);

        // Next is not indexed: each remaining field is a literal without
        // indexing whose name is given by index 15 + 47 = 62 (the first
        // dynamic-table entry, i.e. the `hello` name inserted above) followed
        // by a Huffman-coded 3-byte value (RFC 7541 §6.2.2, §5.1).
        assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
        assert_eq!("zomg", huff_decode(&dst[15..18]));
        assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
        assert_eq!("sup", huff_decode(&dst[21..]));
    }

    /// Huffman-decodes `src` into a fresh buffer; panics if decoding fails.
    fn huff_decode(src: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        huffman::decode(src, &mut buf).unwrap()
    }
}
1036