1 use super::{util, StreamDependency, StreamId};
2 use crate::ext::Protocol;
3 use crate::frame::{Error, Frame, Head, Kind};
4 use crate::hpack::{self, BytesStr};
5
6 use http::header::{self, HeaderName, HeaderValue};
7 use http::{uri, HeaderMap, Method, Request, StatusCode, Uri};
8
9 use bytes::{BufMut, Bytes, BytesMut};
10
11 use std::fmt;
12 use std::io::Cursor;
13
/// Destination buffer used when encoding frames: a `BytesMut` wrapped in a
/// `bytes::buf::Limit`, so `remaining_mut()` reports how many more bytes may
/// be written before the caller-imposed limit is reached.
type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>;

/// Header frame
///
/// This could be a request, a response, or trailers (see
/// `Headers::trailers`). When a frame is received, `Headers::load` parses the
/// frame prefix and `Headers::load_hpack` decodes the header block.
#[derive(Eq, PartialEq)]
pub struct Headers {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The stream dependency information, if any. Only set by `load`, when
    /// the received frame has the PRIORITY flag.
    stream_dep: Option<StreamDependency>,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: HeadersFlag,
}
32
/// Flags byte of a HEADERS frame (`END_STREAM`, `END_HEADERS`, `PADDED`,
/// `PRIORITY`). `HeadersFlag::load` masks unknown bits; constructing the
/// tuple struct directly keeps the byte as-is.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
35
/// PUSH_PROMISE frame: sent on `stream_id`, it reserves `promised_id` for a
/// promised request whose headers are carried in `header_block`.
#[derive(Eq, PartialEq)]
pub struct PushPromise {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The ID of the stream being reserved by this PushPromise.
    promised_id: StreamId,

    /// The header block fragment
    header_block: HeaderBlock,

    /// The associated flags
    flags: PushPromiseFlag,
}
50
/// Flags byte of a PUSH_PROMISE frame; only `END_HEADERS` and `PADDED` are
/// queried through its methods.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct PushPromiseFlag(u8);
53
/// The part of an encoded header block that did not fit into the previous
/// frame and still has to be written as one or more CONTINUATION frames.
#[derive(Debug)]
pub struct Continuation {
    /// Stream ID of continuation frame
    stream_id: StreamId,

    /// HPACK bytes that have not been written yet.
    header_block: EncodingHeaderBlock,
}
61
// TODO: These fields shouldn't be `pub`
/// Pseudo header fields of a header block. `None` means the pseudo header is
/// absent. Requests use the first group of fields, responses use `status`.
#[derive(Debug, Default, Eq, PartialEq)]
pub struct Pseudo {
    // Request
    pub method: Option<Method>,
    pub scheme: Option<BytesStr>,
    pub authority: Option<BytesStr>,
    pub path: Option<BytesStr>,
    pub protocol: Option<Protocol>,

    // Response
    pub status: Option<StatusCode>,
}
75
/// Iterator handed to the HPACK encoder: yields the pseudo headers first and
/// then the regular header fields.
#[derive(Debug)]
pub struct Iter {
    /// Pseudo headers; replaced by `None` once none of them are left.
    pseudo: Option<Pseudo>,

    /// Header fields
    fields: header::IntoIter<HeaderValue>,
}
84
/// A header block in decoded form: regular fields plus pseudo headers. It is
/// filled by `HeaderBlock::load` when receiving and turned into HPACK bytes
/// by `HeaderBlock::into_encoding` when sending.
#[derive(Debug, PartialEq, Eq)]
struct HeaderBlock {
    /// The decoded header fields
    fields: HeaderMap,

    /// Set to true if decoding went over the max header list size.
    is_over_size: bool,

    /// Pseudo headers, these are broken out as they must be sent as part of the
    /// headers frame.
    pseudo: Pseudo,
}
97
/// An HPACK-encoded header block being written out frame by frame.
#[derive(Debug)]
struct EncodingHeaderBlock {
    /// Encoded bytes that have not been written into a frame yet.
    hpack: Bytes,
}
102
// Flag bits of the HEADERS / PUSH_PROMISE / CONTINUATION frames handled in
// this module (values from RFC 7540, Sections 6.2, 6.6 and 6.10).
const END_STREAM: u8 = 0x1;
const END_HEADERS: u8 = 0x4;
const PADDED: u8 = 0x8;
const PRIORITY: u8 = 0x20;
// Mask of every flag bit above; used by the `load` constructors of the flag
// types to drop all other bits.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
108
109 // ===== impl Headers =====
110
111 impl Headers {
112 /// Create a new HEADERS frame
new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self113 pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self {
114 Headers {
115 stream_id,
116 stream_dep: None,
117 header_block: HeaderBlock {
118 fields,
119 is_over_size: false,
120 pseudo,
121 },
122 flags: HeadersFlag::default(),
123 }
124 }
125
trailers(stream_id: StreamId, fields: HeaderMap) -> Self126 pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
127 let mut flags = HeadersFlag::default();
128 flags.set_end_stream();
129
130 Headers {
131 stream_id,
132 stream_dep: None,
133 header_block: HeaderBlock {
134 fields,
135 is_over_size: false,
136 pseudo: Pseudo::default(),
137 },
138 flags,
139 }
140 }
141
142 /// Loads the header frame but doesn't actually do HPACK decoding.
143 ///
144 /// HPACK decoding is done in the `load_hpack` step.
load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error>145 pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
146 let flags = HeadersFlag(head.flag());
147 let mut pad = 0;
148
149 tracing::trace!("loading headers; flags={:?}", flags);
150
151 if head.stream_id().is_zero() {
152 return Err(Error::InvalidStreamId);
153 }
154
155 // Read the padding length
156 if flags.is_padded() {
157 if src.is_empty() {
158 return Err(Error::MalformedMessage);
159 }
160 pad = src[0] as usize;
161
162 // Drop the padding
163 let _ = src.split_to(1);
164 }
165
166 // Read the stream dependency
167 let stream_dep = if flags.is_priority() {
168 if src.len() < 5 {
169 return Err(Error::MalformedMessage);
170 }
171 let stream_dep = StreamDependency::load(&src[..5])?;
172
173 if stream_dep.dependency_id() == head.stream_id() {
174 return Err(Error::InvalidDependencyId);
175 }
176
177 // Drop the next 5 bytes
178 let _ = src.split_to(5);
179
180 Some(stream_dep)
181 } else {
182 None
183 };
184
185 if pad > 0 {
186 if pad > src.len() {
187 return Err(Error::TooMuchPadding);
188 }
189
190 let len = src.len() - pad;
191 src.truncate(len);
192 }
193
194 let headers = Headers {
195 stream_id: head.stream_id(),
196 stream_dep,
197 header_block: HeaderBlock {
198 fields: HeaderMap::new(),
199 is_over_size: false,
200 pseudo: Pseudo::default(),
201 },
202 flags,
203 };
204
205 Ok((headers, src))
206 }
207
load_hpack( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), Error>208 pub fn load_hpack(
209 &mut self,
210 src: &mut BytesMut,
211 max_header_list_size: usize,
212 decoder: &mut hpack::Decoder,
213 ) -> Result<(), Error> {
214 self.header_block.load(src, max_header_list_size, decoder)
215 }
216
stream_id(&self) -> StreamId217 pub fn stream_id(&self) -> StreamId {
218 self.stream_id
219 }
220
is_end_headers(&self) -> bool221 pub fn is_end_headers(&self) -> bool {
222 self.flags.is_end_headers()
223 }
224
set_end_headers(&mut self)225 pub fn set_end_headers(&mut self) {
226 self.flags.set_end_headers();
227 }
228
is_end_stream(&self) -> bool229 pub fn is_end_stream(&self) -> bool {
230 self.flags.is_end_stream()
231 }
232
set_end_stream(&mut self)233 pub fn set_end_stream(&mut self) {
234 self.flags.set_end_stream()
235 }
236
is_over_size(&self) -> bool237 pub fn is_over_size(&self) -> bool {
238 self.header_block.is_over_size
239 }
240
into_parts(self) -> (Pseudo, HeaderMap)241 pub fn into_parts(self) -> (Pseudo, HeaderMap) {
242 (self.header_block.pseudo, self.header_block.fields)
243 }
244
245 #[cfg(feature = "unstable")]
pseudo_mut(&mut self) -> &mut Pseudo246 pub fn pseudo_mut(&mut self) -> &mut Pseudo {
247 &mut self.header_block.pseudo
248 }
249
250 /// Whether it has status 1xx
is_informational(&self) -> bool251 pub(crate) fn is_informational(&self) -> bool {
252 self.header_block.pseudo.is_informational()
253 }
254
fields(&self) -> &HeaderMap255 pub fn fields(&self) -> &HeaderMap {
256 &self.header_block.fields
257 }
258
into_fields(self) -> HeaderMap259 pub fn into_fields(self) -> HeaderMap {
260 self.header_block.fields
261 }
262
encode( self, encoder: &mut hpack::Encoder, dst: &mut EncodeBuf<'_>, ) -> Option<Continuation>263 pub fn encode(
264 self,
265 encoder: &mut hpack::Encoder,
266 dst: &mut EncodeBuf<'_>,
267 ) -> Option<Continuation> {
268 // At this point, the `is_end_headers` flag should always be set
269 debug_assert!(self.flags.is_end_headers());
270
271 // Get the HEADERS frame head
272 let head = self.head();
273
274 self.header_block
275 .into_encoding(encoder)
276 .encode(&head, dst, |_| {})
277 }
278
head(&self) -> Head279 fn head(&self) -> Head {
280 Head::new(Kind::Headers, self.flags.into(), self.stream_id)
281 }
282 }
283
284 impl<T> From<Headers> for Frame<T> {
from(src: Headers) -> Self285 fn from(src: Headers) -> Self {
286 Frame::Headers(src)
287 }
288 }
289
290 impl fmt::Debug for Headers {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result291 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
292 let mut builder = f.debug_struct("Headers");
293 builder
294 .field("stream_id", &self.stream_id)
295 .field("flags", &self.flags);
296
297 if let Some(ref protocol) = self.header_block.pseudo.protocol {
298 builder.field("protocol", protocol);
299 }
300
301 if let Some(ref dep) = self.stream_dep {
302 builder.field("stream_dep", dep);
303 }
304
305 // `fields` and `pseudo` purposefully not included
306 builder.finish()
307 }
308 }
309
310 // ===== util =====
311
/// Parses an unsigned decimal integer such as a `content-length` value.
///
/// Only a non-empty run of ASCII digits `0-9` is accepted: no sign, no
/// whitespace, no other characters.
///
/// # Errors
///
/// Returns `Err(())` when `src` is empty, contains a non-digit byte, or is
/// longer than 19 digits (longer inputs could overflow a `u64`).
pub fn parse_u64(src: &[u8]) -> Result<u64, ()> {
    // An empty value contains no number at all; `Content-Length = 1*DIGIT`
    // (RFC 9110, Section 8.6) requires at least one digit. 20+ digits are
    // rejected up front because they could overflow.
    if src.is_empty() || src.len() > 19 {
        return Err(());
    }

    src.iter().try_fold(0u64, |acc, &d| {
        if d.is_ascii_digit() {
            // Cannot overflow: at most 19 digits, and 10^19 - 1 < u64::MAX.
            Ok(acc * 10 + u64::from(d - b'0'))
        } else {
            Err(())
        }
    })
}
331
332 // ===== impl PushPromise =====
333
/// Reasons `PushPromise::validate_request` rejects a promised request.
#[derive(Debug)]
pub enum PushPromiseHeaderError {
    /// A `content-length` header is present and does not parse as `0`;
    /// carries the result of `parse_u64` on its value.
    InvalidContentLength(Result<u64, ()>),
    /// The method is neither `GET` nor `HEAD`.
    NotSafeAndCacheable,
}
339
340 impl PushPromise {
new( stream_id: StreamId, promised_id: StreamId, pseudo: Pseudo, fields: HeaderMap, ) -> Self341 pub fn new(
342 stream_id: StreamId,
343 promised_id: StreamId,
344 pseudo: Pseudo,
345 fields: HeaderMap,
346 ) -> Self {
347 PushPromise {
348 flags: PushPromiseFlag::default(),
349 header_block: HeaderBlock {
350 fields,
351 is_over_size: false,
352 pseudo,
353 },
354 promised_id,
355 stream_id,
356 }
357 }
358
validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError>359 pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> {
360 use PushPromiseHeaderError::*;
361 // The spec has some requirements for promised request headers
362 // [https://httpwg.org/specs/rfc7540.html#PushRequests]
363
364 // A promised request "that indicates the presence of a request body
365 // MUST reset the promised stream with a stream error"
366 if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) {
367 let parsed_length = parse_u64(content_length.as_bytes());
368 if parsed_length != Ok(0) {
369 return Err(InvalidContentLength(parsed_length));
370 }
371 }
372 // "The server MUST include a method in the :method pseudo-header field
373 // that is safe and cacheable"
374 if !Self::safe_and_cacheable(req.method()) {
375 return Err(NotSafeAndCacheable);
376 }
377
378 Ok(())
379 }
380
safe_and_cacheable(method: &Method) -> bool381 fn safe_and_cacheable(method: &Method) -> bool {
382 // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods
383 // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods
384 return method == Method::GET || method == Method::HEAD;
385 }
386
fields(&self) -> &HeaderMap387 pub fn fields(&self) -> &HeaderMap {
388 &self.header_block.fields
389 }
390
391 #[cfg(feature = "unstable")]
into_fields(self) -> HeaderMap392 pub fn into_fields(self) -> HeaderMap {
393 self.header_block.fields
394 }
395
396 /// Loads the push promise frame but doesn't actually do HPACK decoding.
397 ///
398 /// HPACK decoding is done in the `load_hpack` step.
load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error>399 pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
400 let flags = PushPromiseFlag(head.flag());
401 let mut pad = 0;
402
403 // Read the padding length
404 if flags.is_padded() {
405 if src.is_empty() {
406 return Err(Error::MalformedMessage);
407 }
408
409 // TODO: Ensure payload is sized correctly
410 pad = src[0] as usize;
411
412 // Drop the padding
413 let _ = src.split_to(1);
414 }
415
416 if src.len() < 5 {
417 return Err(Error::MalformedMessage);
418 }
419
420 let (promised_id, _) = StreamId::parse(&src[..4]);
421 // Drop promised_id bytes
422 let _ = src.split_to(4);
423
424 if pad > 0 {
425 if pad > src.len() {
426 return Err(Error::TooMuchPadding);
427 }
428
429 let len = src.len() - pad;
430 src.truncate(len);
431 }
432
433 let frame = PushPromise {
434 flags,
435 header_block: HeaderBlock {
436 fields: HeaderMap::new(),
437 is_over_size: false,
438 pseudo: Pseudo::default(),
439 },
440 promised_id,
441 stream_id: head.stream_id(),
442 };
443 Ok((frame, src))
444 }
445
load_hpack( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), Error>446 pub fn load_hpack(
447 &mut self,
448 src: &mut BytesMut,
449 max_header_list_size: usize,
450 decoder: &mut hpack::Decoder,
451 ) -> Result<(), Error> {
452 self.header_block.load(src, max_header_list_size, decoder)
453 }
454
stream_id(&self) -> StreamId455 pub fn stream_id(&self) -> StreamId {
456 self.stream_id
457 }
458
promised_id(&self) -> StreamId459 pub fn promised_id(&self) -> StreamId {
460 self.promised_id
461 }
462
is_end_headers(&self) -> bool463 pub fn is_end_headers(&self) -> bool {
464 self.flags.is_end_headers()
465 }
466
set_end_headers(&mut self)467 pub fn set_end_headers(&mut self) {
468 self.flags.set_end_headers();
469 }
470
is_over_size(&self) -> bool471 pub fn is_over_size(&self) -> bool {
472 self.header_block.is_over_size
473 }
474
encode( self, encoder: &mut hpack::Encoder, dst: &mut EncodeBuf<'_>, ) -> Option<Continuation>475 pub fn encode(
476 self,
477 encoder: &mut hpack::Encoder,
478 dst: &mut EncodeBuf<'_>,
479 ) -> Option<Continuation> {
480 // At this point, the `is_end_headers` flag should always be set
481 debug_assert!(self.flags.is_end_headers());
482
483 let head = self.head();
484 let promised_id = self.promised_id;
485
486 self.header_block
487 .into_encoding(encoder)
488 .encode(&head, dst, |dst| {
489 dst.put_u32(promised_id.into());
490 })
491 }
492
head(&self) -> Head493 fn head(&self) -> Head {
494 Head::new(Kind::PushPromise, self.flags.into(), self.stream_id)
495 }
496
497 /// Consume `self`, returning the parts of the frame
into_parts(self) -> (Pseudo, HeaderMap)498 pub fn into_parts(self) -> (Pseudo, HeaderMap) {
499 (self.header_block.pseudo, self.header_block.fields)
500 }
501 }
502
503 impl<T> From<PushPromise> for Frame<T> {
from(src: PushPromise) -> Self504 fn from(src: PushPromise) -> Self {
505 Frame::PushPromise(src)
506 }
507 }
508
509 impl fmt::Debug for PushPromise {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result510 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
511 f.debug_struct("PushPromise")
512 .field("stream_id", &self.stream_id)
513 .field("promised_id", &self.promised_id)
514 .field("flags", &self.flags)
515 // `fields` and `pseudo` purposefully not included
516 .finish()
517 }
518 }
519
520 // ===== impl Continuation =====
521
522 impl Continuation {
head(&self) -> Head523 fn head(&self) -> Head {
524 Head::new(Kind::Continuation, END_HEADERS, self.stream_id)
525 }
526
encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation>527 pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> {
528 // Get the CONTINUATION frame head
529 let head = self.head();
530
531 self.header_block.encode(&head, dst, |_| {})
532 }
533 }
534
535 // ===== impl Pseudo =====
536
537 impl Pseudo {
request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self538 pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
539 let parts = uri::Parts::from(uri);
540
541 let mut path = parts
542 .path_and_query
543 .map(|v| BytesStr::from(v.as_str()))
544 .unwrap_or(BytesStr::from_static(""));
545
546 match method {
547 Method::OPTIONS | Method::CONNECT => {}
548 _ if path.is_empty() => {
549 path = BytesStr::from_static("/");
550 }
551 _ => {}
552 }
553
554 let mut pseudo = Pseudo {
555 method: Some(method),
556 scheme: None,
557 authority: None,
558 path: Some(path).filter(|p| !p.is_empty()),
559 protocol,
560 status: None,
561 };
562
563 // If the URI includes a scheme component, add it to the pseudo headers
564 //
565 // TODO: Scheme must be set...
566 if let Some(scheme) = parts.scheme {
567 pseudo.set_scheme(scheme);
568 }
569
570 // If the URI includes an authority component, add it to the pseudo
571 // headers
572 if let Some(authority) = parts.authority {
573 pseudo.set_authority(BytesStr::from(authority.as_str()));
574 }
575
576 pseudo
577 }
578
response(status: StatusCode) -> Self579 pub fn response(status: StatusCode) -> Self {
580 Pseudo {
581 method: None,
582 scheme: None,
583 authority: None,
584 path: None,
585 protocol: None,
586 status: Some(status),
587 }
588 }
589
590 #[cfg(feature = "unstable")]
set_status(&mut self, value: StatusCode)591 pub fn set_status(&mut self, value: StatusCode) {
592 self.status = Some(value);
593 }
594
set_scheme(&mut self, scheme: uri::Scheme)595 pub fn set_scheme(&mut self, scheme: uri::Scheme) {
596 let bytes_str = match scheme.as_str() {
597 "http" => BytesStr::from_static("http"),
598 "https" => BytesStr::from_static("https"),
599 s => BytesStr::from(s),
600 };
601 self.scheme = Some(bytes_str);
602 }
603
604 #[cfg(feature = "unstable")]
set_protocol(&mut self, protocol: Protocol)605 pub fn set_protocol(&mut self, protocol: Protocol) {
606 self.protocol = Some(protocol);
607 }
608
set_authority(&mut self, authority: BytesStr)609 pub fn set_authority(&mut self, authority: BytesStr) {
610 self.authority = Some(authority);
611 }
612
613 /// Whether it has status 1xx
is_informational(&self) -> bool614 pub(crate) fn is_informational(&self) -> bool {
615 self.status
616 .map_or(false, |status| status.is_informational())
617 }
618 }
619
620 // ===== impl EncodingHeaderBlock =====
621
622 impl EncodingHeaderBlock {
encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation> where F: FnOnce(&mut EncodeBuf<'_>),623 fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation>
624 where
625 F: FnOnce(&mut EncodeBuf<'_>),
626 {
627 let head_pos = dst.get_ref().len();
628
629 // At this point, we don't know how big the h2 frame will be.
630 // So, we write the head with length 0, then write the body, and
631 // finally write the length once we know the size.
632 head.encode(0, dst);
633
634 let payload_pos = dst.get_ref().len();
635
636 f(dst);
637
638 // Now, encode the header payload
639 let continuation = if self.hpack.len() > dst.remaining_mut() {
640 dst.put_slice(&self.hpack.split_to(dst.remaining_mut()));
641
642 Some(Continuation {
643 stream_id: head.stream_id(),
644 header_block: self,
645 })
646 } else {
647 dst.put_slice(&self.hpack);
648
649 None
650 };
651
652 // Compute the header block length
653 let payload_len = (dst.get_ref().len() - payload_pos) as u64;
654
655 // Write the frame length
656 let payload_len_be = payload_len.to_be_bytes();
657 assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
658 (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]);
659
660 if continuation.is_some() {
661 // There will be continuation frames, so the `is_end_headers` flag
662 // must be unset
663 debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS);
664
665 dst.get_mut()[head_pos + 4] -= END_HEADERS;
666 }
667
668 continuation
669 }
670 }
671
672 // ===== impl Iter =====
673
674 impl Iterator for Iter {
675 type Item = hpack::Header<Option<HeaderName>>;
676
next(&mut self) -> Option<Self::Item>677 fn next(&mut self) -> Option<Self::Item> {
678 use crate::hpack::Header::*;
679
680 if let Some(ref mut pseudo) = self.pseudo {
681 if let Some(method) = pseudo.method.take() {
682 return Some(Method(method));
683 }
684
685 if let Some(scheme) = pseudo.scheme.take() {
686 return Some(Scheme(scheme));
687 }
688
689 if let Some(authority) = pseudo.authority.take() {
690 return Some(Authority(authority));
691 }
692
693 if let Some(path) = pseudo.path.take() {
694 return Some(Path(path));
695 }
696
697 if let Some(protocol) = pseudo.protocol.take() {
698 return Some(Protocol(protocol));
699 }
700
701 if let Some(status) = pseudo.status.take() {
702 return Some(Status(status));
703 }
704 }
705
706 self.pseudo = None;
707
708 self.fields
709 .next()
710 .map(|(name, value)| Field { name, value })
711 }
712 }
713
714 // ===== impl HeadersFlag =====
715
716 impl HeadersFlag {
empty() -> HeadersFlag717 pub fn empty() -> HeadersFlag {
718 HeadersFlag(0)
719 }
720
load(bits: u8) -> HeadersFlag721 pub fn load(bits: u8) -> HeadersFlag {
722 HeadersFlag(bits & ALL)
723 }
724
is_end_stream(&self) -> bool725 pub fn is_end_stream(&self) -> bool {
726 self.0 & END_STREAM == END_STREAM
727 }
728
set_end_stream(&mut self)729 pub fn set_end_stream(&mut self) {
730 self.0 |= END_STREAM;
731 }
732
is_end_headers(&self) -> bool733 pub fn is_end_headers(&self) -> bool {
734 self.0 & END_HEADERS == END_HEADERS
735 }
736
set_end_headers(&mut self)737 pub fn set_end_headers(&mut self) {
738 self.0 |= END_HEADERS;
739 }
740
is_padded(&self) -> bool741 pub fn is_padded(&self) -> bool {
742 self.0 & PADDED == PADDED
743 }
744
is_priority(&self) -> bool745 pub fn is_priority(&self) -> bool {
746 self.0 & PRIORITY == PRIORITY
747 }
748 }
749
750 impl Default for HeadersFlag {
751 /// Returns a `HeadersFlag` value with `END_HEADERS` set.
default() -> Self752 fn default() -> Self {
753 HeadersFlag(END_HEADERS)
754 }
755 }
756
757 impl From<HeadersFlag> for u8 {
from(src: HeadersFlag) -> u8758 fn from(src: HeadersFlag) -> u8 {
759 src.0
760 }
761 }
762
763 impl fmt::Debug for HeadersFlag {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result764 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
765 util::debug_flags(fmt, self.0)
766 .flag_if(self.is_end_headers(), "END_HEADERS")
767 .flag_if(self.is_end_stream(), "END_STREAM")
768 .flag_if(self.is_padded(), "PADDED")
769 .flag_if(self.is_priority(), "PRIORITY")
770 .finish()
771 }
772 }
773
774 // ===== impl PushPromiseFlag =====
775
776 impl PushPromiseFlag {
empty() -> PushPromiseFlag777 pub fn empty() -> PushPromiseFlag {
778 PushPromiseFlag(0)
779 }
780
load(bits: u8) -> PushPromiseFlag781 pub fn load(bits: u8) -> PushPromiseFlag {
782 PushPromiseFlag(bits & ALL)
783 }
784
is_end_headers(&self) -> bool785 pub fn is_end_headers(&self) -> bool {
786 self.0 & END_HEADERS == END_HEADERS
787 }
788
set_end_headers(&mut self)789 pub fn set_end_headers(&mut self) {
790 self.0 |= END_HEADERS;
791 }
792
is_padded(&self) -> bool793 pub fn is_padded(&self) -> bool {
794 self.0 & PADDED == PADDED
795 }
796 }
797
798 impl Default for PushPromiseFlag {
799 /// Returns a `PushPromiseFlag` value with `END_HEADERS` set.
default() -> Self800 fn default() -> Self {
801 PushPromiseFlag(END_HEADERS)
802 }
803 }
804
805 impl From<PushPromiseFlag> for u8 {
from(src: PushPromiseFlag) -> u8806 fn from(src: PushPromiseFlag) -> u8 {
807 src.0
808 }
809 }
810
811 impl fmt::Debug for PushPromiseFlag {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result812 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
813 util::debug_flags(fmt, self.0)
814 .flag_if(self.is_end_headers(), "END_HEADERS")
815 .flag_if(self.is_padded(), "PADDED")
816 .finish()
817 }
818 }
819
820 // ===== HeaderBlock =====
821
822 impl HeaderBlock {
load( &mut self, src: &mut BytesMut, max_header_list_size: usize, decoder: &mut hpack::Decoder, ) -> Result<(), Error>823 fn load(
824 &mut self,
825 src: &mut BytesMut,
826 max_header_list_size: usize,
827 decoder: &mut hpack::Decoder,
828 ) -> Result<(), Error> {
829 let mut reg = !self.fields.is_empty();
830 let mut malformed = false;
831 let mut headers_size = self.calculate_header_list_size();
832
833 macro_rules! set_pseudo {
834 ($field:ident, $val:expr) => {{
835 if reg {
836 tracing::trace!("load_hpack; header malformed -- pseudo not at head of block");
837 malformed = true;
838 } else if self.pseudo.$field.is_some() {
839 tracing::trace!("load_hpack; header malformed -- repeated pseudo");
840 malformed = true;
841 } else {
842 let __val = $val;
843 headers_size +=
844 decoded_header_size(stringify!($field).len() + 1, __val.as_str().len());
845 if headers_size < max_header_list_size {
846 self.pseudo.$field = Some(__val);
847 } else if !self.is_over_size {
848 tracing::trace!("load_hpack; header list size over max");
849 self.is_over_size = true;
850 }
851 }
852 }};
853 }
854
855 let mut cursor = Cursor::new(src);
856
857 // If the header frame is malformed, we still have to continue decoding
858 // the headers. A malformed header frame is a stream level error, but
859 // the hpack state is connection level. In order to maintain correct
860 // state for other streams, the hpack decoding process must complete.
861 let res = decoder.decode(&mut cursor, |header| {
862 use crate::hpack::Header::*;
863
864 match header {
865 Field { name, value } => {
866 // Connection level header fields are not supported and must
867 // result in a protocol error.
868
869 if name == header::CONNECTION
870 || name == header::TRANSFER_ENCODING
871 || name == header::UPGRADE
872 || name == "keep-alive"
873 || name == "proxy-connection"
874 {
875 tracing::trace!("load_hpack; connection level header");
876 malformed = true;
877 } else if name == header::TE && value != "trailers" {
878 tracing::trace!(
879 "load_hpack; TE header not set to trailers; val={:?}",
880 value
881 );
882 malformed = true;
883 } else {
884 reg = true;
885
886 headers_size += decoded_header_size(name.as_str().len(), value.len());
887 if headers_size < max_header_list_size {
888 self.fields.append(name, value);
889 } else if !self.is_over_size {
890 tracing::trace!("load_hpack; header list size over max");
891 self.is_over_size = true;
892 }
893 }
894 }
895 Authority(v) => set_pseudo!(authority, v),
896 Method(v) => set_pseudo!(method, v),
897 Scheme(v) => set_pseudo!(scheme, v),
898 Path(v) => set_pseudo!(path, v),
899 Protocol(v) => set_pseudo!(protocol, v),
900 Status(v) => set_pseudo!(status, v),
901 }
902 });
903
904 if let Err(e) = res {
905 tracing::trace!("hpack decoding error; err={:?}", e);
906 return Err(e.into());
907 }
908
909 if malformed {
910 tracing::trace!("malformed message");
911 return Err(Error::MalformedMessage);
912 }
913
914 Ok(())
915 }
916
into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock917 fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock {
918 let mut hpack = BytesMut::new();
919 let headers = Iter {
920 pseudo: Some(self.pseudo),
921 fields: self.fields.into_iter(),
922 };
923
924 encoder.encode(headers, &mut hpack);
925
926 EncodingHeaderBlock {
927 hpack: hpack.freeze(),
928 }
929 }
930
931 /// Calculates the size of the currently decoded header list.
932 ///
933 /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
934 ///
935 /// > The value is based on the uncompressed size of header fields,
936 /// > including the length of the name and value in octets plus an
937 /// > overhead of 32 octets for each header field.
calculate_header_list_size(&self) -> usize938 fn calculate_header_list_size(&self) -> usize {
939 macro_rules! pseudo_size {
940 ($name:ident) => {{
941 self.pseudo
942 .$name
943 .as_ref()
944 .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
945 .unwrap_or(0)
946 }};
947 }
948
949 pseudo_size!(method)
950 + pseudo_size!(scheme)
951 + pseudo_size!(status)
952 + pseudo_size!(authority)
953 + pseudo_size!(path)
954 + self
955 .fields
956 .iter()
957 .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len()))
958 .sum::<usize>()
959 }
960 }
961
/// Size of one decoded header field as counted against
/// `SETTINGS_MAX_HEADER_LIST_SIZE`: the name and value lengths in octets plus
/// a fixed per-field overhead of 32 octets.
fn decoded_header_size(name: usize, value: usize) -> usize {
    const PER_FIELD_OVERHEAD: usize = 32;
    PER_FIELD_OVERHEAD + name + value
}
965
// Unit tests for header block encoding across HEADERS and CONTINUATION frames.
#[cfg(test)]
mod test {
    use std::iter::FromIterator;

    use http::HeaderValue;

    use super::*;
    use crate::frame;
    use crate::hpack::{huffman, Encoder};

    /// Encodes three `hello` fields into a buffer with room for only 8 payload
    /// bytes, so the header block is split in the middle of the first value.
    /// The returned `Continuation` must then write the rest of the block,
    /// including the repeated `hello` fields, which reach the encoder without
    /// a name of their own.
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        let headers = Headers::new(
            StreamId::ZERO,
            Default::default(),
            HeaderMap::from_iter(vec![
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("world"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("zomg"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("sup"),
                ),
            ]),
        );

        let continuation = headers
            .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8))
            .unwrap();

        // First frame: 8-byte payload, type 1 (HEADERS), flags 0 (END_HEADERS
        // was cleared because a continuation follows), stream 0.
        assert_eq!(17, dst.len());
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        assert_eq!(0x80 | 4, dst[15]);

        // Only the first byte of the encoded `world` fit in the first frame;
        // keep it and join it with the bytes from the CONTINUATION frame.
        let mut world = dst[16..17].to_owned();

        dst.clear();

        assert!(continuation
            .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16))
            .is_none());

        world.extend_from_slice(&dst[9..12]);
        assert_eq!("world", huff_decode(&world));

        // CONTINUATION frame: 15-byte payload, type 9, flags 4 (END_HEADERS).
        assert_eq!(24, dst.len());
        assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]);

        // The next two fields are not indexed.
        assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
        assert_eq!("zomg", huff_decode(&dst[15..18]));
        assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
        assert_eq!("sup", huff_decode(&dst[21..]));
    }

    /// Huffman-decodes `src`, panicking on invalid input.
    fn huff_decode(src: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        huffman::decode(src, &mut buf).unwrap()
    }
}
1036