1 use super::{util, StreamDependency, StreamId};
2 use crate::frame::{Error, Frame, Head, Kind};
3 use crate::hpack::{self, BytesStr};
4
5 use http::header::{self, HeaderName, HeaderValue};
6 use http::{uri, HeaderMap, Method, Request, StatusCode, Uri};
7
8 use bytes::BytesMut;
9
10 use std::fmt;
11 use std::io::Cursor;
12
13 type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>;
14
15 // Minimum MAX_FRAME_SIZE is 16kb, so save some arbitrary space for frame
16 // head and other header bits.
17 const MAX_HEADER_LENGTH: usize = 1024 * 16 - 100;
18
19 /// Header frame
20 ///
21 /// This could be either a request or a response.
22 #[derive(Eq, PartialEq)]
23 pub struct Headers {
24 /// The ID of the stream with which this frame is associated.
25 stream_id: StreamId,
26
27 /// The stream dependency information, if any.
28 stream_dep: Option<StreamDependency>,
29
30 /// The header block fragment
31 header_block: HeaderBlock,
32
33 /// The associated flags
34 flags: HeadersFlag,
35 }
36
37 #[derive(Copy, Clone, Eq, PartialEq)]
38 pub struct HeadersFlag(u8);
39
40 #[derive(Eq, PartialEq)]
41 pub struct PushPromise {
42 /// The ID of the stream with which this frame is associated.
43 stream_id: StreamId,
44
45 /// The ID of the stream being reserved by this PushPromise.
46 promised_id: StreamId,
47
48 /// The header block fragment
49 header_block: HeaderBlock,
50
51 /// The associated flags
52 flags: PushPromiseFlag,
53 }
54
55 #[derive(Copy, Clone, Eq, PartialEq)]
56 pub struct PushPromiseFlag(u8);
57
58 #[derive(Debug)]
59 pub struct Continuation {
60 /// Stream ID of continuation frame
61 stream_id: StreamId,
62
63 header_block: EncodingHeaderBlock,
64 }
65
66 // TODO: These fields shouldn't be `pub`
67 #[derive(Debug, Default, Eq, PartialEq)]
68 pub struct Pseudo {
69 // Request
70 pub method: Option<Method>,
71 pub scheme: Option<BytesStr>,
72 pub authority: Option<BytesStr>,
73 pub path: Option<BytesStr>,
74
75 // Response
76 pub status: Option<StatusCode>,
77 }
78
79 #[derive(Debug)]
80 pub struct Iter {
81 /// Pseudo headers
82 pseudo: Option<Pseudo>,
83
84 /// Header fields
85 fields: header::IntoIter<HeaderValue>,
86 }
87
88 #[derive(Debug, PartialEq, Eq)]
89 struct HeaderBlock {
90 /// The decoded header fields
91 fields: HeaderMap,
92
93 /// Set to true if decoding went over the max header list size.
94 is_over_size: bool,
95
96 /// Pseudo headers, these are broken out as they must be sent as part of the
97 /// headers frame.
98 pseudo: Pseudo,
99 }
100
101 #[derive(Debug)]
102 struct EncodingHeaderBlock {
103 /// Argument to pass to the HPACK encoder to resume encoding
104 hpack: Option<hpack::EncodeState>,
105
106 /// remaining headers to encode
107 headers: Iter,
108 }
109
// Flag bits carried in the frame head of HEADERS / PUSH_PROMISE /
// CONTINUATION frames.
const END_STREAM: u8 = 0x01;
const END_HEADERS: u8 = 0x04;
const PADDED: u8 = 0x08;
const PRIORITY: u8 = 0x20;

// Union of every flag bit defined for a HEADERS frame.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
115
116 // ===== impl Headers =====
117
118 impl Headers {
119 /// Create a new HEADERS frame
new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self120 pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self {
121 Headers {
122 stream_id,
123 stream_dep: None,
124 header_block: HeaderBlock {
125 fields,
126 is_over_size: false,
127 pseudo,
128 },
129 flags: HeadersFlag::default(),
130 }
131 }
132
trailers(stream_id: StreamId, fields: HeaderMap) -> Self133 pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
134 let mut flags = HeadersFlag::default();
135 flags.set_end_stream();
136
137 Headers {
138 stream_id,
139 stream_dep: None,
140 header_block: HeaderBlock {
141 fields,
142 is_over_size: false,
143 pseudo: Pseudo::default(),
144 },
145 flags,
146 }
147 }
148
149 /// Loads the header frame but doesn't actually do HPACK decoding.
150 ///
151 /// HPACK decoding is done in the `load_hpack` step.
load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error>152 pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
153 let flags = HeadersFlag(head.flag());
154 let mut pad = 0;
155
156 tracing::trace!("loading headers; flags={:?}", flags);
157
158 // Read the padding length
159 if flags.is_padded() {
160 if src.is_empty() {
161 return Err(Error::MalformedMessage);
162 }
163 pad = src[0] as usize;
164
165 // Drop the padding
166 let _ = src.split_to(1);
167 }
168
169 // Read the stream dependency
170 let stream_dep = if flags.is_priority() {
171 if src.len() < 5 {
172 return Err(Error::MalformedMessage);
173 }
174 let stream_dep = StreamDependency::load(&src[..5])?;
175
176 if stream_dep.dependency_id() == head.stream_id() {
177 return Err(Error::InvalidDependencyId);
178 }
179
180 // Drop the next 5 bytes
181 let _ = src.split_to(5);
182
183 Some(stream_dep)
184 } else {
185 None
186 };
187
188 if pad > 0 {
189 if pad > src.len() {
190 return Err(Error::TooMuchPadding);
191 }
192
193 let len = src.len() - pad;
194 src.truncate(len);
195 }
196
197 let headers = Headers {
198 stream_id: head.stream_id(),
199 stream_dep,
200 header_block: HeaderBlock {
201 fields: HeaderMap::new(),
202 is_over_size: false,
203 pseudo: Pseudo::default(),
204 },
205 flags,
206 };
207
208 Ok((headers, src))
209 }
210
load_hpack( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), Error>211 pub fn load_hpack(
212 &mut self,
213 src: &mut BytesMut,
214 max_header_list_size: usize,
215 decoder: &mut hpack::Decoder,
216 ) -> Result<(), Error> {
217 self.header_block.load(src, max_header_list_size, decoder)
218 }
219
stream_id(&self) -> StreamId220 pub fn stream_id(&self) -> StreamId {
221 self.stream_id
222 }
223
is_end_headers(&self) -> bool224 pub fn is_end_headers(&self) -> bool {
225 self.flags.is_end_headers()
226 }
227
set_end_headers(&mut self)228 pub fn set_end_headers(&mut self) {
229 self.flags.set_end_headers();
230 }
231
is_end_stream(&self) -> bool232 pub fn is_end_stream(&self) -> bool {
233 self.flags.is_end_stream()
234 }
235
set_end_stream(&mut self)236 pub fn set_end_stream(&mut self) {
237 self.flags.set_end_stream()
238 }
239
is_over_size(&self) -> bool240 pub fn is_over_size(&self) -> bool {
241 self.header_block.is_over_size
242 }
243
has_too_big_field(&self) -> bool244 pub(crate) fn has_too_big_field(&self) -> bool {
245 self.header_block.has_too_big_field()
246 }
247
into_parts(self) -> (Pseudo, HeaderMap)248 pub fn into_parts(self) -> (Pseudo, HeaderMap) {
249 (self.header_block.pseudo, self.header_block.fields)
250 }
251
252 #[cfg(feature = "unstable")]
pseudo_mut(&mut self) -> &mut Pseudo253 pub fn pseudo_mut(&mut self) -> &mut Pseudo {
254 &mut self.header_block.pseudo
255 }
256
257 /// Whether it has status 1xx
is_informational(&self) -> bool258 pub(crate) fn is_informational(&self) -> bool {
259 self.header_block.pseudo.is_informational()
260 }
261
fields(&self) -> &HeaderMap262 pub fn fields(&self) -> &HeaderMap {
263 &self.header_block.fields
264 }
265
into_fields(self) -> HeaderMap266 pub fn into_fields(self) -> HeaderMap {
267 self.header_block.fields
268 }
269
encode( self, encoder: &mut hpack::Encoder, dst: &mut EncodeBuf<'_>, ) -> Option<Continuation>270 pub fn encode(
271 self,
272 encoder: &mut hpack::Encoder,
273 dst: &mut EncodeBuf<'_>,
274 ) -> Option<Continuation> {
275 // At this point, the `is_end_headers` flag should always be set
276 debug_assert!(self.flags.is_end_headers());
277
278 // Get the HEADERS frame head
279 let head = self.head();
280
281 self.header_block
282 .into_encoding()
283 .encode(&head, encoder, dst, |_| {})
284 }
285
head(&self) -> Head286 fn head(&self) -> Head {
287 Head::new(Kind::Headers, self.flags.into(), self.stream_id)
288 }
289 }
290
291 impl<T> From<Headers> for Frame<T> {
from(src: Headers) -> Self292 fn from(src: Headers) -> Self {
293 Frame::Headers(src)
294 }
295 }
296
297 impl fmt::Debug for Headers {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result298 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
299 let mut builder = f.debug_struct("Headers");
300 builder
301 .field("stream_id", &self.stream_id)
302 .field("flags", &self.flags);
303
304 if let Some(ref dep) = self.stream_dep {
305 builder.field("stream_dep", dep);
306 }
307
308 // `fields` and `pseudo` purposefully not included
309 builder.finish()
310 }
311 }
312
313 // ===== util =====
314
/// Parses an unsigned decimal integer, e.g. a `content-length` value.
///
/// The input must be a non-empty run of ASCII digits. Leading zeros are
/// allowed, as HTTP's `1*DIGIT` grammar permits them.
///
/// # Errors
///
/// Returns `Err(())` if the input is empty, contains any byte that is not an
/// ASCII digit, or denotes a value larger than `u64::MAX`.
pub fn parse_u64(src: &[u8]) -> Result<u64, ()> {
    if src.is_empty() {
        // `1*DIGIT` requires at least one digit; an empty value is not a number.
        return Err(());
    }

    src.iter().try_fold(0u64, |acc, &byte| {
        if !byte.is_ascii_digit() {
            return Err(());
        }

        // Checked arithmetic instead of a digit-count cap: every value that
        // fits in a u64 is accepted and anything larger is rejected.
        acc.checked_mul(10)
            .and_then(|acc| acc.checked_add(u64::from(byte - b'0')))
            .ok_or(())
    })
}
334
335 // ===== impl PushPromise =====
336
/// Why a request was rejected as the request of a PUSH_PROMISE.
#[derive(Debug)]
pub enum PushPromiseHeaderError {
    /// A `content-length` header is present but is not `0`. Carries the
    /// outcome of parsing the header value.
    InvalidContentLength(Result<u64, ()>),
    /// The request method is not both safe and cacheable.
    NotSafeAndCacheable,
}
342
343 impl PushPromise {
new( stream_id: StreamId, promised_id: StreamId, pseudo: Pseudo, fields: HeaderMap, ) -> Self344 pub fn new(
345 stream_id: StreamId,
346 promised_id: StreamId,
347 pseudo: Pseudo,
348 fields: HeaderMap,
349 ) -> Self {
350 PushPromise {
351 flags: PushPromiseFlag::default(),
352 header_block: HeaderBlock {
353 fields,
354 is_over_size: false,
355 pseudo,
356 },
357 promised_id,
358 stream_id,
359 }
360 }
361
validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError>362 pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> {
363 use PushPromiseHeaderError::*;
364 // The spec has some requirements for promised request headers
365 // [https://httpwg.org/specs/rfc7540.html#PushRequests]
366
367 // A promised request "that indicates the presence of a request body
368 // MUST reset the promised stream with a stream error"
369 if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) {
370 let parsed_length = parse_u64(content_length.as_bytes());
371 if parsed_length != Ok(0) {
372 return Err(InvalidContentLength(parsed_length));
373 }
374 }
375 // "The server MUST include a method in the :method pseudo-header field
376 // that is safe and cacheable"
377 if !Self::safe_and_cacheable(req.method()) {
378 return Err(NotSafeAndCacheable);
379 }
380
381 Ok(())
382 }
383
safe_and_cacheable(method: &Method) -> bool384 fn safe_and_cacheable(method: &Method) -> bool {
385 // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods
386 // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods
387 return method == Method::GET || method == Method::HEAD;
388 }
389
fields(&self) -> &HeaderMap390 pub fn fields(&self) -> &HeaderMap {
391 &self.header_block.fields
392 }
393
394 #[cfg(feature = "unstable")]
into_fields(self) -> HeaderMap395 pub fn into_fields(self) -> HeaderMap {
396 self.header_block.fields
397 }
398
399 /// Loads the push promise frame but doesn't actually do HPACK decoding.
400 ///
401 /// HPACK decoding is done in the `load_hpack` step.
load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error>402 pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
403 let flags = PushPromiseFlag(head.flag());
404 let mut pad = 0;
405
406 // Read the padding length
407 if flags.is_padded() {
408 if src.is_empty() {
409 return Err(Error::MalformedMessage);
410 }
411
412 // TODO: Ensure payload is sized correctly
413 pad = src[0] as usize;
414
415 // Drop the padding
416 let _ = src.split_to(1);
417 }
418
419 if src.len() < 5 {
420 return Err(Error::MalformedMessage);
421 }
422
423 let (promised_id, _) = StreamId::parse(&src[..4]);
424 // Drop promised_id bytes
425 let _ = src.split_to(4);
426
427 if pad > 0 {
428 if pad > src.len() {
429 return Err(Error::TooMuchPadding);
430 }
431
432 let len = src.len() - pad;
433 src.truncate(len);
434 }
435
436 let frame = PushPromise {
437 flags,
438 header_block: HeaderBlock {
439 fields: HeaderMap::new(),
440 is_over_size: false,
441 pseudo: Pseudo::default(),
442 },
443 promised_id,
444 stream_id: head.stream_id(),
445 };
446 Ok((frame, src))
447 }
448
load_hpack( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), Error>449 pub fn load_hpack(
450 &mut self,
451 src: &mut BytesMut,
452 max_header_list_size: usize,
453 decoder: &mut hpack::Decoder,
454 ) -> Result<(), Error> {
455 self.header_block.load(src, max_header_list_size, decoder)
456 }
457
stream_id(&self) -> StreamId458 pub fn stream_id(&self) -> StreamId {
459 self.stream_id
460 }
461
promised_id(&self) -> StreamId462 pub fn promised_id(&self) -> StreamId {
463 self.promised_id
464 }
465
is_end_headers(&self) -> bool466 pub fn is_end_headers(&self) -> bool {
467 self.flags.is_end_headers()
468 }
469
set_end_headers(&mut self)470 pub fn set_end_headers(&mut self) {
471 self.flags.set_end_headers();
472 }
473
is_over_size(&self) -> bool474 pub fn is_over_size(&self) -> bool {
475 self.header_block.is_over_size
476 }
477
encode( self, encoder: &mut hpack::Encoder, dst: &mut EncodeBuf<'_>, ) -> Option<Continuation>478 pub fn encode(
479 self,
480 encoder: &mut hpack::Encoder,
481 dst: &mut EncodeBuf<'_>,
482 ) -> Option<Continuation> {
483 use bytes::BufMut;
484
485 // At this point, the `is_end_headers` flag should always be set
486 debug_assert!(self.flags.is_end_headers());
487
488 let head = self.head();
489 let promised_id = self.promised_id;
490
491 self.header_block
492 .into_encoding()
493 .encode(&head, encoder, dst, |dst| {
494 dst.put_u32(promised_id.into());
495 })
496 }
497
head(&self) -> Head498 fn head(&self) -> Head {
499 Head::new(Kind::PushPromise, self.flags.into(), self.stream_id)
500 }
501
502 /// Consume `self`, returning the parts of the frame
into_parts(self) -> (Pseudo, HeaderMap)503 pub fn into_parts(self) -> (Pseudo, HeaderMap) {
504 (self.header_block.pseudo, self.header_block.fields)
505 }
506 }
507
508 impl<T> From<PushPromise> for Frame<T> {
from(src: PushPromise) -> Self509 fn from(src: PushPromise) -> Self {
510 Frame::PushPromise(src)
511 }
512 }
513
514 impl fmt::Debug for PushPromise {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result515 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
516 f.debug_struct("PushPromise")
517 .field("stream_id", &self.stream_id)
518 .field("promised_id", &self.promised_id)
519 .field("flags", &self.flags)
520 // `fields` and `pseudo` purposefully not included
521 .finish()
522 }
523 }
524
525 // ===== impl Continuation =====
526
527 impl Continuation {
head(&self) -> Head528 fn head(&self) -> Head {
529 Head::new(Kind::Continuation, END_HEADERS, self.stream_id)
530 }
531
encode( self, encoder: &mut hpack::Encoder, dst: &mut EncodeBuf<'_>, ) -> Option<Continuation>532 pub fn encode(
533 self,
534 encoder: &mut hpack::Encoder,
535 dst: &mut EncodeBuf<'_>,
536 ) -> Option<Continuation> {
537 // Get the CONTINUATION frame head
538 let head = self.head();
539
540 self.header_block.encode(&head, encoder, dst, |_| {})
541 }
542 }
543
544 // ===== impl Pseudo =====
545
546 impl Pseudo {
request(method: Method, uri: Uri) -> Self547 pub fn request(method: Method, uri: Uri) -> Self {
548 let parts = uri::Parts::from(uri);
549
550 let mut path = parts
551 .path_and_query
552 .map(|v| BytesStr::from(v.as_str()))
553 .unwrap_or(BytesStr::from_static(""));
554
555 match method {
556 Method::OPTIONS | Method::CONNECT => {}
557 _ if path.is_empty() => {
558 path = BytesStr::from_static("/");
559 }
560 _ => {}
561 }
562
563 let mut pseudo = Pseudo {
564 method: Some(method),
565 scheme: None,
566 authority: None,
567 path: Some(path).filter(|p| !p.is_empty()),
568 status: None,
569 };
570
571 // If the URI includes a scheme component, add it to the pseudo headers
572 //
573 // TODO: Scheme must be set...
574 if let Some(scheme) = parts.scheme {
575 pseudo.set_scheme(scheme);
576 }
577
578 // If the URI includes an authority component, add it to the pseudo
579 // headers
580 if let Some(authority) = parts.authority {
581 pseudo.set_authority(BytesStr::from(authority.as_str()));
582 }
583
584 pseudo
585 }
586
response(status: StatusCode) -> Self587 pub fn response(status: StatusCode) -> Self {
588 Pseudo {
589 method: None,
590 scheme: None,
591 authority: None,
592 path: None,
593 status: Some(status),
594 }
595 }
596
set_scheme(&mut self, scheme: uri::Scheme)597 pub fn set_scheme(&mut self, scheme: uri::Scheme) {
598 let bytes_str = match scheme.as_str() {
599 "http" => BytesStr::from_static("http"),
600 "https" => BytesStr::from_static("https"),
601 s => BytesStr::from(s),
602 };
603 self.scheme = Some(bytes_str);
604 }
605
set_authority(&mut self, authority: BytesStr)606 pub fn set_authority(&mut self, authority: BytesStr) {
607 self.authority = Some(authority);
608 }
609
610 /// Whether it has status 1xx
is_informational(&self) -> bool611 pub(crate) fn is_informational(&self) -> bool {
612 self.status
613 .map_or(false, |status| status.is_informational())
614 }
615 }
616
617 // ===== impl EncodingHeaderBlock =====
618
619 impl EncodingHeaderBlock {
encode<F>( mut self, head: &Head, encoder: &mut hpack::Encoder, dst: &mut EncodeBuf<'_>, f: F, ) -> Option<Continuation> where F: FnOnce(&mut EncodeBuf<'_>),620 fn encode<F>(
621 mut self,
622 head: &Head,
623 encoder: &mut hpack::Encoder,
624 dst: &mut EncodeBuf<'_>,
625 f: F,
626 ) -> Option<Continuation>
627 where
628 F: FnOnce(&mut EncodeBuf<'_>),
629 {
630 let head_pos = dst.get_ref().len();
631
632 // At this point, we don't know how big the h2 frame will be.
633 // So, we write the head with length 0, then write the body, and
634 // finally write the length once we know the size.
635 head.encode(0, dst);
636
637 let payload_pos = dst.get_ref().len();
638
639 f(dst);
640
641 // Now, encode the header payload
642 let continuation = match encoder.encode(self.hpack, &mut self.headers, dst) {
643 hpack::Encode::Full => None,
644 hpack::Encode::Partial(state) => Some(Continuation {
645 stream_id: head.stream_id(),
646 header_block: EncodingHeaderBlock {
647 hpack: Some(state),
648 headers: self.headers,
649 },
650 }),
651 };
652
653 // Compute the header block length
654 let payload_len = (dst.get_ref().len() - payload_pos) as u64;
655
656 // Write the frame length
657 let payload_len_be = payload_len.to_be_bytes();
658 assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
659 (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]);
660
661 if continuation.is_some() {
662 // There will be continuation frames, so the `is_end_headers` flag
663 // must be unset
664 debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS);
665
666 dst.get_mut()[head_pos + 4] -= END_HEADERS;
667 }
668
669 continuation
670 }
671 }
672
673 // ===== impl Iter =====
674
675 impl Iterator for Iter {
676 type Item = hpack::Header<Option<HeaderName>>;
677
next(&mut self) -> Option<Self::Item>678 fn next(&mut self) -> Option<Self::Item> {
679 use crate::hpack::Header::*;
680
681 if let Some(ref mut pseudo) = self.pseudo {
682 if let Some(method) = pseudo.method.take() {
683 return Some(Method(method));
684 }
685
686 if let Some(scheme) = pseudo.scheme.take() {
687 return Some(Scheme(scheme));
688 }
689
690 if let Some(authority) = pseudo.authority.take() {
691 return Some(Authority(authority));
692 }
693
694 if let Some(path) = pseudo.path.take() {
695 return Some(Path(path));
696 }
697
698 if let Some(status) = pseudo.status.take() {
699 return Some(Status(status));
700 }
701 }
702
703 self.pseudo = None;
704
705 self.fields
706 .next()
707 .map(|(name, value)| Field { name, value })
708 }
709 }
710
711 // ===== impl HeadersFlag =====
712
713 impl HeadersFlag {
empty() -> HeadersFlag714 pub fn empty() -> HeadersFlag {
715 HeadersFlag(0)
716 }
717
load(bits: u8) -> HeadersFlag718 pub fn load(bits: u8) -> HeadersFlag {
719 HeadersFlag(bits & ALL)
720 }
721
is_end_stream(&self) -> bool722 pub fn is_end_stream(&self) -> bool {
723 self.0 & END_STREAM == END_STREAM
724 }
725
set_end_stream(&mut self)726 pub fn set_end_stream(&mut self) {
727 self.0 |= END_STREAM;
728 }
729
is_end_headers(&self) -> bool730 pub fn is_end_headers(&self) -> bool {
731 self.0 & END_HEADERS == END_HEADERS
732 }
733
set_end_headers(&mut self)734 pub fn set_end_headers(&mut self) {
735 self.0 |= END_HEADERS;
736 }
737
is_padded(&self) -> bool738 pub fn is_padded(&self) -> bool {
739 self.0 & PADDED == PADDED
740 }
741
is_priority(&self) -> bool742 pub fn is_priority(&self) -> bool {
743 self.0 & PRIORITY == PRIORITY
744 }
745 }
746
747 impl Default for HeadersFlag {
748 /// Returns a `HeadersFlag` value with `END_HEADERS` set.
default() -> Self749 fn default() -> Self {
750 HeadersFlag(END_HEADERS)
751 }
752 }
753
754 impl From<HeadersFlag> for u8 {
from(src: HeadersFlag) -> u8755 fn from(src: HeadersFlag) -> u8 {
756 src.0
757 }
758 }
759
760 impl fmt::Debug for HeadersFlag {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result761 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
762 util::debug_flags(fmt, self.0)
763 .flag_if(self.is_end_headers(), "END_HEADERS")
764 .flag_if(self.is_end_stream(), "END_STREAM")
765 .flag_if(self.is_padded(), "PADDED")
766 .flag_if(self.is_priority(), "PRIORITY")
767 .finish()
768 }
769 }
770
771 // ===== impl PushPromiseFlag =====
772
773 impl PushPromiseFlag {
empty() -> PushPromiseFlag774 pub fn empty() -> PushPromiseFlag {
775 PushPromiseFlag(0)
776 }
777
load(bits: u8) -> PushPromiseFlag778 pub fn load(bits: u8) -> PushPromiseFlag {
779 PushPromiseFlag(bits & ALL)
780 }
781
is_end_headers(&self) -> bool782 pub fn is_end_headers(&self) -> bool {
783 self.0 & END_HEADERS == END_HEADERS
784 }
785
set_end_headers(&mut self)786 pub fn set_end_headers(&mut self) {
787 self.0 |= END_HEADERS;
788 }
789
is_padded(&self) -> bool790 pub fn is_padded(&self) -> bool {
791 self.0 & PADDED == PADDED
792 }
793 }
794
795 impl Default for PushPromiseFlag {
796 /// Returns a `PushPromiseFlag` value with `END_HEADERS` set.
default() -> Self797 fn default() -> Self {
798 PushPromiseFlag(END_HEADERS)
799 }
800 }
801
802 impl From<PushPromiseFlag> for u8 {
from(src: PushPromiseFlag) -> u8803 fn from(src: PushPromiseFlag) -> u8 {
804 src.0
805 }
806 }
807
808 impl fmt::Debug for PushPromiseFlag {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result809 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
810 util::debug_flags(fmt, self.0)
811 .flag_if(self.is_end_headers(), "END_HEADERS")
812 .flag_if(self.is_padded(), "PADDED")
813 .finish()
814 }
815 }
816
817 // ===== HeaderBlock =====
818
819 impl HeaderBlock {
load( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), Error>820 fn load(
821 &mut self,
822 src: &mut BytesMut,
823 max_header_list_size: usize,
824 decoder: &mut hpack::Decoder,
825 ) -> Result<(), Error> {
826 let mut reg = !self.fields.is_empty();
827 let mut malformed = false;
828 let mut headers_size = self.calculate_header_list_size();
829
830 macro_rules! set_pseudo {
831 ($field:ident, $val:expr) => {{
832 if reg {
833 tracing::trace!("load_hpack; header malformed -- pseudo not at head of block");
834 malformed = true;
835 } else if self.pseudo.$field.is_some() {
836 tracing::trace!("load_hpack; header malformed -- repeated pseudo");
837 malformed = true;
838 } else {
839 let __val = $val;
840 headers_size +=
841 decoded_header_size(stringify!($field).len() + 1, __val.as_str().len());
842 if headers_size < max_header_list_size {
843 self.pseudo.$field = Some(__val);
844 } else if !self.is_over_size {
845 tracing::trace!("load_hpack; header list size over max");
846 self.is_over_size = true;
847 }
848 }
849 }};
850 }
851
852 let mut cursor = Cursor::new(src);
853
854 // If the header frame is malformed, we still have to continue decoding
855 // the headers. A malformed header frame is a stream level error, but
856 // the hpack state is connection level. In order to maintain correct
857 // state for other streams, the hpack decoding process must complete.
858 let res = decoder.decode(&mut cursor, |header| {
859 use crate::hpack::Header::*;
860
861 match header {
862 Field { name, value } => {
863 // Connection level header fields are not supported and must
864 // result in a protocol error.
865
866 if name == header::CONNECTION
867 || name == header::TRANSFER_ENCODING
868 || name == header::UPGRADE
869 || name == "keep-alive"
870 || name == "proxy-connection"
871 {
872 tracing::trace!("load_hpack; connection level header");
873 malformed = true;
874 } else if name == header::TE && value != "trailers" {
875 tracing::trace!(
876 "load_hpack; TE header not set to trailers; val={:?}",
877 value
878 );
879 malformed = true;
880 } else {
881 reg = true;
882
883 headers_size += decoded_header_size(name.as_str().len(), value.len());
884 if headers_size < max_header_list_size {
885 self.fields.append(name, value);
886 } else if !self.is_over_size {
887 tracing::trace!("load_hpack; header list size over max");
888 self.is_over_size = true;
889 }
890 }
891 }
892 Authority(v) => set_pseudo!(authority, v),
893 Method(v) => set_pseudo!(method, v),
894 Scheme(v) => set_pseudo!(scheme, v),
895 Path(v) => set_pseudo!(path, v),
896 Status(v) => set_pseudo!(status, v),
897 }
898 });
899
900 if let Err(e) = res {
901 tracing::trace!("hpack decoding error; err={:?}", e);
902 return Err(e.into());
903 }
904
905 if malformed {
906 tracing::trace!("malformed message");
907 return Err(Error::MalformedMessage);
908 }
909
910 Ok(())
911 }
912
into_encoding(self) -> EncodingHeaderBlock913 fn into_encoding(self) -> EncodingHeaderBlock {
914 EncodingHeaderBlock {
915 hpack: None,
916 headers: Iter {
917 pseudo: Some(self.pseudo),
918 fields: self.fields.into_iter(),
919 },
920 }
921 }
922
923 /// Calculates the size of the currently decoded header list.
924 ///
925 /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
926 ///
927 /// > The value is based on the uncompressed size of header fields,
928 /// > including the length of the name and value in octets plus an
929 /// > overhead of 32 octets for each header field.
calculate_header_list_size(&self) -> usize930 fn calculate_header_list_size(&self) -> usize {
931 macro_rules! pseudo_size {
932 ($name:ident) => {{
933 self.pseudo
934 .$name
935 .as_ref()
936 .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
937 .unwrap_or(0)
938 }};
939 }
940
941 pseudo_size!(method)
942 + pseudo_size!(scheme)
943 + pseudo_size!(status)
944 + pseudo_size!(authority)
945 + pseudo_size!(path)
946 + self
947 .fields
948 .iter()
949 .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len()))
950 .sum::<usize>()
951 }
952
953 /// Iterate over all pseudos and headers to see if any individual pair
954 /// would be too large to encode.
has_too_big_field(&self) -> bool955 pub(crate) fn has_too_big_field(&self) -> bool {
956 macro_rules! pseudo_size {
957 ($name:ident) => {{
958 self.pseudo
959 .$name
960 .as_ref()
961 .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
962 .unwrap_or(0)
963 }};
964 }
965
966 if pseudo_size!(method) > MAX_HEADER_LENGTH {
967 return true;
968 }
969
970 if pseudo_size!(scheme) > MAX_HEADER_LENGTH {
971 return true;
972 }
973
974 if pseudo_size!(authority) > MAX_HEADER_LENGTH {
975 return true;
976 }
977
978 if pseudo_size!(path) > MAX_HEADER_LENGTH {
979 return true;
980 }
981
982 // skip :status, its never going to be too big
983
984 for (name, value) in &self.fields {
985 if decoded_header_size(name.as_str().len(), value.len()) > MAX_HEADER_LENGTH {
986 return true;
987 }
988 }
989
990 false
991 }
992 }
993
/// Size of one header field as counted against SETTINGS_MAX_HEADER_LIST_SIZE.
///
/// The size is the name length plus the value length plus a fixed 32-octet
/// per-field overhead.
fn decoded_header_size(name: usize, value: usize) -> usize {
    const PER_FIELD_OVERHEAD: usize = 32;
    PER_FIELD_OVERHEAD + name + value
}
997