1 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4 // option. This file may not be copied, modified, or distributed
5 // except according to those terms.
6 
7 // Encoding and decoding packets off the wire.
8 use crate::cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdRef, MAX_CONNECTION_ID_LEN};
9 use crate::crypto::{CryptoDxState, CryptoSpace, CryptoStates};
10 use crate::{Error, Res};
11 
12 use neqo_common::{hex, hex_with_len, qtrace, qwarn, Decoder, Encoder};
13 use neqo_crypto::random;
14 
15 use std::cmp::min;
16 use std::convert::TryFrom;
17 use std::fmt;
18 use std::iter::ExactSizeIterator;
19 use std::ops::{Deref, DerefMut, Range};
20 use std::time::Instant;
21 
// Long header packet type codes, carried in bits 4-5 of the first byte
// (written as `code() << 4`, read back as `(first >> 4) & 3`).
const PACKET_TYPE_INITIAL: u8 = 0x0;
const PACKET_TYPE_0RTT: u8 = 0x01;
const PACKET_TYPE_HANDSHAKE: u8 = 0x2;
const PACKET_TYPE_RETRY: u8 = 0x03;

/// Header form bit of the first byte: set for long headers, clear for short headers.
pub const PACKET_BIT_LONG: u8 = 0x80;
/// Header form bit value for short header packets.
const PACKET_BIT_SHORT: u8 = 0x00;
/// The fixed ("QUIC") bit; `PacketBuilder::scramble` can randomize it on request.
const PACKET_BIT_FIXED_QUIC: u8 = 0x40;
/// The spin bit; `PacketBuilder::scramble` randomizes it on short header packets.
const PACKET_BIT_SPIN: u8 = 0x20;
/// The key phase bit of short header packets.
const PACKET_BIT_KEY_PHASE: u8 = 0x04;

// Bits of the first byte that are covered by header protection,
// for long and short header packets respectively.
const PACKET_HP_MASK_LONG: u8 = 0x0f;
const PACKET_HP_MASK_SHORT: u8 = 0x1f;

// The header protection sample is SAMPLE_SIZE bytes long and starts SAMPLE_OFFSET
// bytes after the start of the packet number field (i.e. as if the packet number
// always used MAX_PACKET_NUMBER_LEN bytes).
const SAMPLE_SIZE: usize = 16;
const SAMPLE_OFFSET: usize = 4;
/// The longest packet number encoding, in bytes.
const MAX_PACKET_NUMBER_LEN: usize = 4;

// Retry packet integrity protection helpers (`retry::use_aead`, `retry::expansion`).
mod retry;

/// A full (untruncated) packet number.
pub type PacketNumber = u64;
/// A 32-bit QUIC version number as it appears in long headers.
type Version = u32;
44 
/// The kind of a QUIC packet, as determined from its (unprotected) header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
    /// A long header packet whose version field is zero.
    VersionNegotiation,
    /// A long header Initial packet.
    Initial,
    /// A long header Handshake packet.
    Handshake,
    /// A long header 0-RTT packet.
    ZeroRtt,
    /// A long header Retry packet.
    Retry,
    /// A short header packet.
    Short,
    /// A long header packet with a version that is not supported here.
    OtherVersion,
}
55 
56 impl PacketType {
57     #[must_use]
code(self) -> u858     fn code(self) -> u8 {
59         match self {
60             Self::Initial => PACKET_TYPE_INITIAL,
61             Self::ZeroRtt => PACKET_TYPE_0RTT,
62             Self::Handshake => PACKET_TYPE_HANDSHAKE,
63             Self::Retry => PACKET_TYPE_RETRY,
64             _ => panic!("shouldn't be here"),
65         }
66     }
67 }
68 
69 impl From<PacketType> for CryptoSpace {
from(v: PacketType) -> Self70     fn from(v: PacketType) -> Self {
71         match v {
72             PacketType::Initial => Self::Initial,
73             PacketType::ZeroRtt => Self::ZeroRtt,
74             PacketType::Handshake => Self::Handshake,
75             PacketType::Short => Self::ApplicationData,
76             _ => panic!("shouldn't be here"),
77         }
78     }
79 }
80 
81 impl From<CryptoSpace> for PacketType {
from(cs: CryptoSpace) -> Self82     fn from(cs: CryptoSpace) -> Self {
83         match cs {
84             CryptoSpace::Initial => Self::Initial,
85             CryptoSpace::ZeroRtt => Self::ZeroRtt,
86             CryptoSpace::Handshake => Self::Handshake,
87             CryptoSpace::ApplicationData => Self::Short,
88         }
89     }
90 }
91 
/// A QUIC version supported by this implementation.
///
/// `Eq` is derived alongside `PartialEq` because equality on this fieldless enum is total.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuicVersion {
    /// QUIC version 1.
    Version1,
    /// draft-ietf-quic-transport-29.
    Draft29,
    /// draft-ietf-quic-transport-30.
    Draft30,
    /// draft-ietf-quic-transport-31.
    Draft31,
    /// draft-ietf-quic-transport-32.
    Draft32,
}
100 
101 impl QuicVersion {
as_u32(self) -> Version102     pub fn as_u32(self) -> Version {
103         match self {
104             Self::Version1 => 1,
105             Self::Draft29 => 0xff00_0000 + 29,
106             Self::Draft30 => 0xff00_0000 + 30,
107             Self::Draft31 => 0xff00_0000 + 31,
108             Self::Draft32 => 0xff00_0000 + 32,
109         }
110     }
111 }
112 
113 impl Default for QuicVersion {
default() -> Self114     fn default() -> Self {
115         Self::Version1
116     }
117 }
118 
119 impl TryFrom<Version> for QuicVersion {
120     type Error = Error;
121 
try_from(ver: Version) -> Res<Self>122     fn try_from(ver: Version) -> Res<Self> {
123         if ver == 1 {
124             Ok(Self::Version1)
125         } else if ver == 0xff00_0000 + 29 {
126             Ok(Self::Draft29)
127         } else if ver == 0xff00_0000 + 30 {
128             Ok(Self::Draft30)
129         } else if ver == 0xff00_0000 + 31 {
130             Ok(Self::Draft31)
131         } else if ver == 0xff00_0000 + 32 {
132             Ok(Self::Draft32)
133         } else {
134             Err(Error::VersionNegotiation)
135         }
136     }
137 }
138 
/// Positions inside the encoder that `PacketBuilder` records while writing a header,
/// so that `build()` can fill in the length and apply header protection.
struct PacketBuilderOffsets {
    /// The bits of the first octet that need masking.
    first_byte_mask: u8,
    /// The offset of the length field.
    /// Only set (by `pn()`) for long header packets; it stays 0 for short headers.
    len: usize,
    /// The location of the packet number field.
    pn: Range<usize>,
}
147 
/// A packet builder that can be used to produce short packets and long packets.
/// This does not produce Retry or Version Negotiation.
pub struct PacketBuilder {
    /// The encoder the packet is appended to; it may already hold earlier packets.
    encoder: Encoder,
    /// The packet number, recorded by `pn()`.
    pn: PacketNumber,
    /// The span of the encoder holding this packet's header.  It starts where the
    /// packet starts, and its end is set by `pn()` to just after the packet number.
    header: Range<usize>,
    /// Field locations that `build()` patches or protects.
    offsets: PacketBuilderOffsets,
    /// The size the encoder may grow to for this packet; 0 marks the builder unusable.
    limit: usize,
    /// Whether to pad the packet before construction.
    padding: bool,
}
159 
160 impl PacketBuilder {
161     /// The minimum useful frame size.  If space is less than this, we will claim to be full.
162     pub const MINIMUM_FRAME_SIZE: usize = 2;
163 
infer_limit(encoder: &Encoder) -> usize164     fn infer_limit(encoder: &Encoder) -> usize {
165         if encoder.capacity() > 64 {
166             encoder.capacity()
167         } else {
168             2048
169         }
170     }
171 
172     /// Start building a short header packet.
173     ///
174     /// This doesn't fail if there isn't enough space; instead it returns a builder that
175     /// has no available space left.  This allows the caller to extract the encoder
176     /// and any packets that might have been added before as adding a packet header is
177     /// only likely to fail if there are other packets already written.
178     ///
179     /// If, after calling this method, `remaining()` returns 0, then call `abort()` to get
180     /// the encoder back.
181     #[allow(clippy::reversed_empty_ranges)]
short(mut encoder: Encoder, key_phase: bool, dcid: impl AsRef<[u8]>) -> Self182     pub fn short(mut encoder: Encoder, key_phase: bool, dcid: impl AsRef<[u8]>) -> Self {
183         let mut limit = Self::infer_limit(&encoder);
184         let header_start = encoder.len();
185         // Check that there is enough space for the header.
186         // 5 = 1 (first byte) + 4 (packet number)
187         if limit > encoder.len() && 5 + dcid.as_ref().len() < limit - encoder.len() {
188             encoder
189                 .encode_byte(PACKET_BIT_SHORT | PACKET_BIT_FIXED_QUIC | (u8::from(key_phase) << 2));
190             encoder.encode(dcid.as_ref());
191         } else {
192             limit = 0;
193         }
194         Self {
195             encoder,
196             pn: u64::max_value(),
197             header: header_start..header_start,
198             offsets: PacketBuilderOffsets {
199                 first_byte_mask: PACKET_HP_MASK_SHORT,
200                 pn: 0..0,
201                 len: 0,
202             },
203             limit,
204             padding: false,
205         }
206     }
207 
208     /// Start building a long header packet.
209     /// For an Initial packet you will need to call initial_token(),
210     /// even if the token is empty.
211     ///
212     /// See `short()` for more on how to handle this in cases where there is no space.
213     #[allow(clippy::reversed_empty_ranges)] // For initializing an empty range.
long( mut encoder: Encoder, pt: PacketType, quic_version: QuicVersion, dcid: impl AsRef<[u8]>, scid: impl AsRef<[u8]>, ) -> Self214     pub fn long(
215         mut encoder: Encoder,
216         pt: PacketType,
217         quic_version: QuicVersion,
218         dcid: impl AsRef<[u8]>,
219         scid: impl AsRef<[u8]>,
220     ) -> Self {
221         let mut limit = Self::infer_limit(&encoder);
222         let header_start = encoder.len();
223         // Check that there is enough space for the header.
224         // 11 = 1 (first byte) + 4 (version) + 2 (dcid+scid length) + 4 (packet number)
225         if limit > encoder.len()
226             && 11 + dcid.as_ref().len() + scid.as_ref().len() < limit - encoder.len()
227         {
228             encoder.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC | pt.code() << 4);
229             encoder.encode_uint(4, quic_version.as_u32());
230             encoder.encode_vec(1, dcid.as_ref());
231             encoder.encode_vec(1, scid.as_ref());
232         } else {
233             limit = 0;
234         }
235 
236         Self {
237             encoder,
238             pn: u64::max_value(),
239             header: header_start..header_start,
240             offsets: PacketBuilderOffsets {
241                 first_byte_mask: PACKET_HP_MASK_LONG,
242                 pn: 0..0,
243                 len: 0,
244             },
245             limit,
246             padding: false,
247         }
248     }
249 
is_long(&self) -> bool250     fn is_long(&self) -> bool {
251         self[self.header.start] & 0x80 == PACKET_BIT_LONG
252     }
253 
254     /// This stores a value that can be used as a limit.  This does not cause
255     /// this limit to be enforced until encryption occurs.  Prior to that, it
256     /// is only used voluntarily by users of the builder, through `remaining()`.
set_limit(&mut self, limit: usize)257     pub fn set_limit(&mut self, limit: usize) {
258         self.limit = limit;
259     }
260 
261     /// Get the current limit.
262     #[must_use]
limit(&mut self) -> usize263     pub fn limit(&mut self) -> usize {
264         self.limit
265     }
266 
267     /// How many bytes remain against the size limit for the builder.
268     #[must_use]
remaining(&self) -> usize269     pub fn remaining(&self) -> usize {
270         self.limit.saturating_sub(self.encoder.len())
271     }
272 
273     /// Returns true if the packet has no more space for frames.
274     #[must_use]
is_full(&self) -> bool275     pub fn is_full(&self) -> bool {
276         // No useful frame is smaller than 2 bytes long.
277         self.limit < self.encoder.len() + Self::MINIMUM_FRAME_SIZE
278     }
279 
280     /// Adjust the limit to ensure that no more data is added.
mark_full(&mut self)281     pub fn mark_full(&mut self) {
282         self.limit = self.encoder.len()
283     }
284 
285     /// Mark the packet as needing padding (or not).
enable_padding(&mut self, needs_padding: bool)286     pub fn enable_padding(&mut self, needs_padding: bool) {
287         self.padding = needs_padding;
288     }
289 
290     /// Maybe pad with "PADDING" frames.
291     /// Only does so if padding was needed and this is a short packet.
292     /// Returns true if padding was added.
pad(&mut self) -> bool293     pub fn pad(&mut self) -> bool {
294         if self.padding && !self.is_long() {
295             self.encoder.pad_to(self.limit, 0);
296             true
297         } else {
298             false
299         }
300     }
301 
302     /// Add unpredictable values for unprotected parts of the packet.
scramble(&mut self, quic_bit: bool)303     pub fn scramble(&mut self, quic_bit: bool) {
304         debug_assert!(self.len() > self.header.start);
305         let mask = if quic_bit { PACKET_BIT_FIXED_QUIC } else { 0 }
306             | if self.is_long() { 0 } else { PACKET_BIT_SPIN };
307         let first = self.header.start;
308         self[first] ^= random(1)[0] & mask;
309     }
310 
311     /// For an Initial packet, encode the token.
312     /// If you fail to do this, then you will not get a valid packet.
initial_token(&mut self, token: &[u8])313     pub fn initial_token(&mut self, token: &[u8]) {
314         debug_assert_eq!(
315             self.encoder[self.header.start] & 0xb0,
316             PACKET_BIT_LONG | PACKET_TYPE_INITIAL << 4
317         );
318         if Encoder::vvec_len(token.len()) < self.remaining() {
319             self.encoder.encode_vvec(token);
320         } else {
321             self.limit = 0;
322         }
323     }
324 
325     /// Add a packet number of the given size.
326     /// For a long header packet, this also inserts a dummy length.
327     /// The length is filled in after calling `build`.
328     /// Does nothing if there isn't 4 bytes available other than render this builder
329     /// unusable; if `remaining()` returns 0 at any point, call `abort()`.
pn(&mut self, pn: PacketNumber, pn_len: usize)330     pub fn pn(&mut self, pn: PacketNumber, pn_len: usize) {
331         if self.remaining() < 4 {
332             self.limit = 0;
333             return;
334         }
335 
336         // Reserve space for a length in long headers.
337         if self.is_long() {
338             self.offsets.len = self.encoder.len();
339             self.encoder.encode(&[0; 2]);
340         }
341 
342         // This allows the input to be >4, which is absurd, but we can eat that.
343         let pn_len = min(MAX_PACKET_NUMBER_LEN, pn_len);
344         debug_assert_ne!(pn_len, 0);
345         // Encode the packet number and save its offset.
346         let pn_offset = self.encoder.len();
347         self.encoder.encode_uint(pn_len, pn);
348         self.offsets.pn = pn_offset..self.encoder.len();
349 
350         // Now encode the packet number length and save the header length.
351         self.encoder[self.header.start] |= u8::try_from(pn_len - 1).unwrap();
352         self.header.end = self.encoder.len();
353         self.pn = pn;
354     }
355 
write_len(&mut self, expansion: usize)356     fn write_len(&mut self, expansion: usize) {
357         let len = self.encoder.len() - (self.offsets.len + 2) + expansion;
358         self.encoder[self.offsets.len] = 0x40 | ((len >> 8) & 0x3f) as u8;
359         self.encoder[self.offsets.len + 1] = (len & 0xff) as u8;
360     }
361 
pad_for_crypto(&mut self, crypto: &mut CryptoDxState)362     fn pad_for_crypto(&mut self, crypto: &mut CryptoDxState) {
363         // Make sure that there is enough data in the packet.
364         // The length of the packet number plus the payload length needs to
365         // be at least 4 (MAX_PACKET_NUMBER_LEN) plus any amount by which
366         // the header protection sample exceeds the AEAD expansion.
367         let crypto_pad = crypto.extra_padding();
368         self.encoder.pad_to(
369             self.offsets.pn.start + MAX_PACKET_NUMBER_LEN + crypto_pad,
370             0,
371         );
372     }
373 
374     /// A lot of frames here are just a collection of varints.
375     /// This helper functions writes a frame like that safely, returning `true` if
376     /// a frame was written.
write_varint_frame(&mut self, values: &[u64]) -> bool377     pub fn write_varint_frame(&mut self, values: &[u64]) -> bool {
378         let write = self.remaining()
379             >= values
380                 .iter()
381                 .map(|&v| Encoder::varint_len(v))
382                 .sum::<usize>();
383         if write {
384             for v in values {
385                 self.encode_varint(*v);
386             }
387             debug_assert!(self.len() <= self.limit());
388         };
389         write
390     }
391 
392     /// Build the packet and return the encoder.
build(mut self, crypto: &mut CryptoDxState) -> Res<Encoder>393     pub fn build(mut self, crypto: &mut CryptoDxState) -> Res<Encoder> {
394         if self.len() > self.limit {
395             qwarn!("Packet contents are more than the limit");
396             debug_assert!(false);
397             return Err(Error::InternalError(5));
398         }
399 
400         self.pad_for_crypto(crypto);
401         if self.offsets.len > 0 {
402             self.write_len(crypto.expansion());
403         }
404 
405         let hdr = &self.encoder[self.header.clone()];
406         let body = &self.encoder[self.header.end..];
407         qtrace!(
408             "Packet build pn={} hdr={} body={}",
409             self.pn,
410             hex(hdr),
411             hex(body)
412         );
413         let ciphertext = crypto.encrypt(self.pn, hdr, body)?;
414 
415         // Calculate the mask.
416         let offset = SAMPLE_OFFSET - self.offsets.pn.len();
417         assert!(offset + SAMPLE_SIZE <= ciphertext.len());
418         let sample = &ciphertext[offset..offset + SAMPLE_SIZE];
419         let mask = crypto.compute_mask(sample)?;
420 
421         // Apply the mask.
422         self.encoder[self.header.start] ^= mask[0] & self.offsets.first_byte_mask;
423         for (i, j) in (1..=self.offsets.pn.len()).zip(self.offsets.pn) {
424             self.encoder[j] ^= mask[i];
425         }
426 
427         // Finally, cut off the plaintext and add back the ciphertext.
428         self.encoder.truncate(self.header.end);
429         self.encoder.encode(&ciphertext);
430         qtrace!("Packet built {}", hex(&self.encoder));
431         Ok(self.encoder)
432     }
433 
434     /// Abort writing of this packet and return the encoder.
435     #[must_use]
abort(mut self) -> Encoder436     pub fn abort(mut self) -> Encoder {
437         self.encoder.truncate(self.header.start);
438         self.encoder
439     }
440 
441     /// Work out if nothing was added after the header.
442     #[must_use]
packet_empty(&self) -> bool443     pub fn packet_empty(&self) -> bool {
444         self.encoder.len() == self.header.end
445     }
446 
447     /// Make a retry packet.
448     /// As this is a simple packet, this is just an associated function.
449     /// As Retry is odd (it has to be constructed with leading bytes),
450     /// this returns a Vec<u8> rather than building on an encoder.
retry( quic_version: QuicVersion, dcid: &[u8], scid: &[u8], token: &[u8], odcid: &[u8], ) -> Res<Vec<u8>>451     pub fn retry(
452         quic_version: QuicVersion,
453         dcid: &[u8],
454         scid: &[u8],
455         token: &[u8],
456         odcid: &[u8],
457     ) -> Res<Vec<u8>> {
458         let mut encoder = Encoder::default();
459         encoder.encode_vec(1, odcid);
460         let start = encoder.len();
461         encoder.encode_byte(
462             PACKET_BIT_LONG
463                 | PACKET_BIT_FIXED_QUIC
464                 | (PACKET_TYPE_RETRY << 4)
465                 | (random(1)[0] & 0xf),
466         );
467         encoder.encode_uint(4, quic_version.as_u32());
468         encoder.encode_vec(1, dcid);
469         encoder.encode_vec(1, scid);
470         debug_assert_ne!(token.len(), 0);
471         encoder.encode(token);
472         let tag = retry::use_aead(quic_version, |aead| {
473             let mut buf = vec![0; aead.expansion()];
474             Ok(aead.encrypt(0, &encoder, &[], &mut buf)?.to_vec())
475         })?;
476         encoder.encode(&tag);
477         let mut complete: Vec<u8> = encoder.into();
478         Ok(complete.split_off(start))
479     }
480 
481     /// Make a Version Negotiation packet.
version_negotiation(dcid: &[u8], scid: &[u8]) -> Vec<u8>482     pub fn version_negotiation(dcid: &[u8], scid: &[u8]) -> Vec<u8> {
483         let mut encoder = Encoder::default();
484         let mut grease = random(5);
485         // This will not include the "QUIC bit" sometimes.  Intentionally.
486         encoder.encode_byte(PACKET_BIT_LONG | (grease[4] & 0x7f));
487         encoder.encode(&[0; 4]); // Zero version == VN.
488         encoder.encode_vec(1, dcid);
489         encoder.encode_vec(1, scid);
490         encoder.encode_uint(4, QuicVersion::Version1.as_u32());
491         encoder.encode_uint(4, QuicVersion::Draft29.as_u32());
492         encoder.encode_uint(4, QuicVersion::Draft30.as_u32());
493         encoder.encode_uint(4, QuicVersion::Draft31.as_u32());
494         encoder.encode_uint(4, QuicVersion::Draft32.as_u32());
495         // Add a greased version, using the randomness already generated.
496         for g in &mut grease[..4] {
497             *g = *g & 0xf0 | 0x0a;
498         }
499         encoder.encode(&grease[0..4]);
500         encoder.into()
501     }
502 }
503 
504 impl Deref for PacketBuilder {
505     type Target = Encoder;
506 
deref(&self) -> &Self::Target507     fn deref(&self) -> &Self::Target {
508         &self.encoder
509     }
510 }
511 
512 impl DerefMut for PacketBuilder {
deref_mut(&mut self) -> &mut Self::Target513     fn deref_mut(&mut self) -> &mut Self::Target {
514         &mut self.encoder
515     }
516 }
517 
518 impl From<PacketBuilder> for Encoder {
from(v: PacketBuilder) -> Self519     fn from(v: PacketBuilder) -> Self {
520         v.encoder
521     }
522 }
523 
/// `PublicPacket` holds information from packets that is public only.  This allows for
/// processing of packets prior to decryption.
pub struct PublicPacket<'a> {
    /// The packet type.
    packet_type: PacketType,
    /// The recovered destination connection ID.
    dcid: ConnectionIdRef<'a>,
    /// The source connection ID, if this is a long header packet.
    scid: Option<ConnectionIdRef<'a>>,
    /// Any token that is included in the packet (Retry always has a token; Initial sometimes does).
    /// This is empty when there is no token.
    token: &'a [u8],
    /// The size of the header, not including the packet number.
    header_len: usize,
    /// Protocol version, if present in header.
    /// `None` for short header, Version Negotiation, and unsupported-version packets.
    quic_version: Option<QuicVersion>,
    /// A reference to the entire packet, including the header.
    /// For long header packets of a supported version this ends where the packet ends;
    /// any bytes that follow in the datagram are returned separately by `decode`.
    data: &'a [u8],
}
543 
544 impl<'a> PublicPacket<'a> {
opt<T>(v: Option<T>) -> Res<T>545     fn opt<T>(v: Option<T>) -> Res<T> {
546         if let Some(v) = v {
547             Ok(v)
548         } else {
549             Err(Error::NoMoreData)
550         }
551     }
552 
553     /// Decode the type-specific portions of a long header.
554     /// This includes reading the length and the remainder of the packet.
555     /// Returns a tuple of any token and the length of the header.
decode_long( decoder: &mut Decoder<'a>, packet_type: PacketType, quic_version: QuicVersion, ) -> Res<(&'a [u8], usize)>556     fn decode_long(
557         decoder: &mut Decoder<'a>,
558         packet_type: PacketType,
559         quic_version: QuicVersion,
560     ) -> Res<(&'a [u8], usize)> {
561         if packet_type == PacketType::Retry {
562             let header_len = decoder.offset();
563             let expansion = retry::expansion(quic_version);
564             let token = Self::opt(decoder.decode(decoder.remaining() - expansion))?;
565             if token.is_empty() {
566                 return Err(Error::InvalidPacket);
567             }
568             Self::opt(decoder.decode(expansion))?;
569             return Ok((token, header_len));
570         }
571         let token = if packet_type == PacketType::Initial {
572             Self::opt(decoder.decode_vvec())?
573         } else {
574             &[]
575         };
576         let len = Self::opt(decoder.decode_varint())?;
577         let header_len = decoder.offset();
578         let _body = Self::opt(decoder.decode(usize::try_from(len)?))?;
579         Ok((token, header_len))
580     }
581 
    /// Decode the common parts of a packet.  This provides minimal parsing and validation.
    /// Returns a tuple of a `PublicPacket` and a slice with any remainder from the datagram.
    ///
    /// Short header packets, Version Negotiation packets, and long header packets with
    /// an unsupported version take all of `data`, so their remainder is empty.  For other
    /// long header packets, the remainder is whatever follows the packet's end.
    ///
    /// # Errors
    /// * `Error::NoMoreData` if the input ends before a required field, or
    ///   `dcid_decoder` cannot find a connection ID.
    /// * `Error::InvalidPacket` if a short header packet is too short for a header
    ///   protection sample, or a supported-version long header has an over-long
    ///   connection ID.
    /// * Any error from `decode_long`.
    pub fn decode(data: &'a [u8], dcid_decoder: &dyn ConnectionIdDecoder) -> Res<(Self, &'a [u8])> {
        let mut decoder = Decoder::new(data);
        let first = Self::opt(decoder.decode_byte())?;

        // The header form bit (0x80) separates short headers from long headers.
        if first & 0x80 == PACKET_BIT_SHORT {
            // Conveniently, this also guarantees that there is enough space
            // for a connection ID of any size.
            if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE {
                return Err(Error::InvalidPacket);
            }
            let dcid = Self::opt(dcid_decoder.decode_cid(&mut decoder))?;
            // After the connection ID, there must still be room for the packet number
            // and the header protection sample.
            if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE {
                return Err(Error::InvalidPacket);
            }
            let header_len = decoder.offset();
            return Ok((
                Self {
                    packet_type: PacketType::Short,
                    dcid,
                    scid: None,
                    token: &[],
                    header_len,
                    quic_version: None,
                    data,
                },
                &[],
            ));
        }

        // Generic long header.
        // A 4-byte integer always fits in `Version`, so this `unwrap` cannot fail.
        let version = Version::try_from(Self::opt(decoder.decode_uint(4))?).unwrap();
        let dcid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?);
        let scid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?);

        // Version negotiation.
        if version == 0 {
            return Ok((
                Self {
                    packet_type: PacketType::VersionNegotiation,
                    dcid,
                    scid: Some(scid),
                    token: &[],
                    header_len: decoder.offset(),
                    quic_version: None,
                    data,
                },
                &[],
            ));
        }

        // Check that this is a long header from a supported version.
        // Packets from other versions are reported before the connection ID length
        // check below, which only applies to supported versions.
        let quic_version = if let Ok(v) = QuicVersion::try_from(version) {
            v
        } else {
            return Ok((
                Self {
                    packet_type: PacketType::OtherVersion,
                    dcid,
                    scid: Some(scid),
                    token: &[],
                    header_len: decoder.offset(),
                    quic_version: None,
                    data,
                },
                &[],
            ));
        };

        if dcid.len() > MAX_CONNECTION_ID_LEN || scid.len() > MAX_CONNECTION_ID_LEN {
            return Err(Error::InvalidPacket);
        }
        // Two type bits give four values, all handled, so the last arm cannot be reached.
        let packet_type = match (first >> 4) & 3 {
            PACKET_TYPE_INITIAL => PacketType::Initial,
            PACKET_TYPE_0RTT => PacketType::ZeroRtt,
            PACKET_TYPE_HANDSHAKE => PacketType::Handshake,
            PACKET_TYPE_RETRY => PacketType::Retry,
            _ => unreachable!(),
        };

        // The type-specific code includes a token.  This consumes the remainder of the packet.
        let (token, header_len) = Self::decode_long(&mut decoder, packet_type, quic_version)?;
        // Anything the decoder did not consume follows this packet in the datagram.
        let end = data.len() - decoder.remaining();
        let (data, remainder) = data.split_at(end);
        Ok((
            Self {
                packet_type,
                dcid,
                scid: Some(scid),
                token,
                header_len,
                quic_version: Some(quic_version),
                data,
            },
            remainder,
        ))
    }
680 
681     /// Validate the given packet as though it were a retry.
is_valid_retry(&self, odcid: &ConnectionId) -> bool682     pub fn is_valid_retry(&self, odcid: &ConnectionId) -> bool {
683         if self.packet_type != PacketType::Retry {
684             return false;
685         }
686         let version = self.quic_version.unwrap();
687         let expansion = retry::expansion(version);
688         if self.data.len() <= expansion {
689             return false;
690         }
691         let (header, tag) = self.data.split_at(self.data.len() - expansion);
692         let mut encoder = Encoder::with_capacity(self.data.len());
693         encoder.encode_vec(1, odcid);
694         encoder.encode(header);
695         retry::use_aead(version, |aead| {
696             let mut buf = vec![0; expansion];
697             Ok(aead.decrypt(0, &encoder, tag, &mut buf)?.is_empty())
698         })
699         .unwrap_or(false)
700     }
701 
is_valid_initial(&self) -> bool702     pub fn is_valid_initial(&self) -> bool {
703         // Packet has to be an initial, with a DCID of 8 bytes, or a token.
704         // Note: the Server class validates the token and checks the length.
705         self.packet_type == PacketType::Initial
706             && (self.dcid().len() >= 8 || !self.token.is_empty())
707     }
708 
packet_type(&self) -> PacketType709     pub fn packet_type(&self) -> PacketType {
710         self.packet_type
711     }
712 
dcid(&self) -> &ConnectionIdRef<'a>713     pub fn dcid(&self) -> &ConnectionIdRef<'a> {
714         &self.dcid
715     }
716 
scid(&self) -> &ConnectionIdRef<'a>717     pub fn scid(&self) -> &ConnectionIdRef<'a> {
718         self.scid
719             .as_ref()
720             .expect("should only be called for long header packets")
721     }
722 
token(&self) -> &'a [u8]723     pub fn token(&self) -> &'a [u8] {
724         self.token
725     }
726 
version(&self) -> Option<QuicVersion>727     pub fn version(&self) -> Option<QuicVersion> {
728         self.quic_version
729     }
730 
len(&self) -> usize731     pub fn len(&self) -> usize {
732         self.data.len()
733     }
734 
decode_pn(expected: PacketNumber, pn: u64, w: usize) -> PacketNumber735     fn decode_pn(expected: PacketNumber, pn: u64, w: usize) -> PacketNumber {
736         let window = 1_u64 << (w * 8);
737         let candidate = (expected & !(window - 1)) | pn;
738         if candidate + (window / 2) <= expected {
739             candidate + window
740         } else if candidate > expected + (window / 2) {
741             match candidate.checked_sub(window) {
742                 Some(pn_sub) => pn_sub,
743                 None => candidate,
744             }
745         } else {
746             candidate
747         }
748     }
749 
750     /// Decrypt the header of the packet.
decrypt_header( &self, crypto: &mut CryptoDxState, ) -> Res<(bool, PacketNumber, Vec<u8>, &'a [u8])>751     fn decrypt_header(
752         &self,
753         crypto: &mut CryptoDxState,
754     ) -> Res<(bool, PacketNumber, Vec<u8>, &'a [u8])> {
755         assert_ne!(self.packet_type, PacketType::Retry);
756         assert_ne!(self.packet_type, PacketType::VersionNegotiation);
757 
758         qtrace!(
759             "unmask hdr={}",
760             hex(&self.data[..self.header_len + SAMPLE_OFFSET])
761         );
762 
763         let sample_offset = self.header_len + SAMPLE_OFFSET;
764         let mask = if let Some(sample) = self.data.get(sample_offset..(sample_offset + SAMPLE_SIZE))
765         {
766             crypto.compute_mask(sample)
767         } else {
768             Err(Error::NoMoreData)
769         }?;
770 
771         // Un-mask the leading byte.
772         let bits = if self.packet_type == PacketType::Short {
773             PACKET_HP_MASK_SHORT
774         } else {
775             PACKET_HP_MASK_LONG
776         };
777         let first_byte = self.data[0] ^ (mask[0] & bits);
778 
779         // Make a copy of the header to work on.
780         let mut hdrbytes = self.data[..self.header_len + 4].to_vec();
781         hdrbytes[0] = first_byte;
782 
783         // Unmask the PN.
784         let mut pn_encoded: u64 = 0;
785         for i in 0..MAX_PACKET_NUMBER_LEN {
786             hdrbytes[self.header_len + i] ^= mask[1 + i];
787             pn_encoded <<= 8;
788             pn_encoded += u64::from(hdrbytes[self.header_len + i]);
789         }
790 
791         // Now decode the packet number length and apply it, hopefully in constant time.
792         let pn_len = usize::from((first_byte & 0x3) + 1);
793         hdrbytes.truncate(self.header_len + pn_len);
794         pn_encoded >>= 8 * (MAX_PACKET_NUMBER_LEN - pn_len);
795 
796         qtrace!("unmasked hdr={}", hex(&hdrbytes));
797 
798         let key_phase = self.packet_type == PacketType::Short
799             && (first_byte & PACKET_BIT_KEY_PHASE) == PACKET_BIT_KEY_PHASE;
800         let pn = Self::decode_pn(crypto.next_pn(), pn_encoded, pn_len);
801         Ok((
802             key_phase,
803             pn,
804             hdrbytes,
805             &self.data[self.header_len + pn_len..],
806         ))
807     }
808 
decrypt(&self, crypto: &mut CryptoStates, release_at: Instant) -> Res<DecryptedPacket>809     pub fn decrypt(&self, crypto: &mut CryptoStates, release_at: Instant) -> Res<DecryptedPacket> {
810         let cspace: CryptoSpace = self.packet_type.into();
811         // This has to work in two stages because we need to remove header protection
812         // before picking the keys to use.
813         if let Some(rx) = crypto.rx_hp(cspace) {
814             // Note that this will dump early, which creates a side-channel.
815             // This is OK in this case because we the only reason this can
816             // fail is if the cryptographic module is bad or the packet is
817             // too small (which is public information).
818             let (key_phase, pn, header, body) = self.decrypt_header(rx)?;
819             qtrace!([rx], "decoded header: {:?}", header);
820             let rx = crypto.rx(cspace, key_phase).unwrap();
821             let d = rx.decrypt(pn, &header, body)?;
822             // If this is the first packet ever successfully decrypted
823             // using `rx`, make sure to initiate a key update.
824             if rx.needs_update() {
825                 crypto.key_update_received(release_at)?;
826             }
827             crypto.check_pn_overlap()?;
828             Ok(DecryptedPacket {
829                 pt: self.packet_type,
830                 pn,
831                 data: d,
832             })
833         } else if crypto.rx_pending(cspace) {
834             Err(Error::KeysPending(cspace))
835         } else {
836             qtrace!("keys for {:?} already discarded", cspace);
837             Err(Error::KeysDiscarded)
838         }
839     }
840 
supported_versions(&self) -> Res<Vec<Version>>841     pub fn supported_versions(&self) -> Res<Vec<Version>> {
842         assert_eq!(self.packet_type, PacketType::VersionNegotiation);
843         let mut decoder = Decoder::new(&self.data[self.header_len..]);
844         let mut res = Vec::new();
845         while decoder.remaining() > 0 {
846             let version = Version::try_from(Self::opt(decoder.decode_uint(4))?)?;
847             res.push(version);
848         }
849         Ok(res)
850     }
851 }
852 
853 impl fmt::Debug for PublicPacket<'_> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result854     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
855         write!(
856             f,
857             "{:?}: {} {}",
858             self.packet_type(),
859             hex_with_len(&self.data[..self.header_len]),
860             hex_with_len(&self.data[self.header_len..])
861         )
862     }
863 }
864 
/// A packet whose header protection has been removed and whose payload has been
/// decrypted.
pub struct DecryptedPacket {
    /// The type of the packet.
    pt: PacketType,
    /// The full packet number recovered during decryption.
    pn: PacketNumber,
    /// The decrypted bytes that followed the packet header.
    data: Vec<u8>,
}
870 
871 impl DecryptedPacket {
packet_type(&self) -> PacketType872     pub fn packet_type(&self) -> PacketType {
873         self.pt
874     }
875 
pn(&self) -> PacketNumber876     pub fn pn(&self) -> PacketNumber {
877         self.pn
878     }
879 }
880 
881 impl Deref for DecryptedPacket {
882     type Target = [u8];
883 
deref(&self) -> &Self::Target884     fn deref(&self) -> &Self::Target {
885         &self.data[..]
886     }
887 }
888 
889 #[cfg(all(test, not(feature = "fuzzing")))]
890 mod tests {
891     use super::*;
892     use crate::crypto::{CryptoDxState, CryptoStates};
893     use crate::{EmptyConnectionIdGenerator, QuicVersion, RandomConnectionIdGenerator};
894     use neqo_common::Encoder;
895     use test_fixture::{fixture_init, now};
896 
    /// Client connection ID shared by the sample packets in these tests (it is the
    /// original destination connection ID for the sample Retry packets).
    const CLIENT_CID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08];
    /// Server connection ID shared by the sample packets in these tests.
    const SERVER_CID: &[u8] = &[0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5];
899 
900     /// This is a connection ID manager, which is only used for decoding short header packets.
cid_mgr() -> RandomConnectionIdGenerator901     fn cid_mgr() -> RandomConnectionIdGenerator {
902         RandomConnectionIdGenerator::new(SERVER_CID.len())
903     }
904 
    /// Plaintext payload of the sample server Initial packet.
    const SAMPLE_INITIAL_PAYLOAD: &[u8] = &[
        0x02, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x40, 0x5a, 0x02, 0x00, 0x00, 0x56, 0x03, 0x03,
        0xee, 0xfc, 0xe7, 0xf7, 0xb3, 0x7b, 0xa1, 0xd1, 0x63, 0x2e, 0x96, 0x67, 0x78, 0x25, 0xdd,
        0xf7, 0x39, 0x88, 0xcf, 0xc7, 0x98, 0x25, 0xdf, 0x56, 0x6d, 0xc5, 0x43, 0x0b, 0x9a, 0x04,
        0x5a, 0x12, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24, 0x00, 0x1d, 0x00,
        0x20, 0x9d, 0x3c, 0x94, 0x0d, 0x89, 0x69, 0x0b, 0x84, 0xd0, 0x8a, 0x60, 0x99, 0x3c, 0x14,
        0x4e, 0xca, 0x68, 0x4d, 0x10, 0x81, 0x28, 0x7c, 0x83, 0x4d, 0x53, 0x11, 0xbc, 0xf3, 0x2b,
        0xb9, 0xda, 0x1a, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04,
    ];
    /// Expected protected server Initial packet: `SAMPLE_INITIAL_PAYLOAD` with packet
    /// number 1, an empty destination connection ID and `SERVER_CID` as source.
    const SAMPLE_INITIAL: &[u8] = &[
        0xcf, 0x00, 0x00, 0x00, 0x01, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
        0x00, 0x40, 0x75, 0xc0, 0xd9, 0x5a, 0x48, 0x2c, 0xd0, 0x99, 0x1c, 0xd2, 0x5b, 0x0a, 0xac,
        0x40, 0x6a, 0x58, 0x16, 0xb6, 0x39, 0x41, 0x00, 0xf3, 0x7a, 0x1c, 0x69, 0x79, 0x75, 0x54,
        0x78, 0x0b, 0xb3, 0x8c, 0xc5, 0xa9, 0x9f, 0x5e, 0xde, 0x4c, 0xf7, 0x3c, 0x3e, 0xc2, 0x49,
        0x3a, 0x18, 0x39, 0xb3, 0xdb, 0xcb, 0xa3, 0xf6, 0xea, 0x46, 0xc5, 0xb7, 0x68, 0x4d, 0xf3,
        0x54, 0x8e, 0x7d, 0xde, 0xb9, 0xc3, 0xbf, 0x9c, 0x73, 0xcc, 0x3f, 0x3b, 0xde, 0xd7, 0x4b,
        0x56, 0x2b, 0xfb, 0x19, 0xfb, 0x84, 0x02, 0x2f, 0x8e, 0xf4, 0xcd, 0xd9, 0x37, 0x95, 0xd7,
        0x7d, 0x06, 0xed, 0xbb, 0x7a, 0xaf, 0x2f, 0x58, 0x89, 0x18, 0x50, 0xab, 0xbd, 0xca, 0x3d,
        0x20, 0x39, 0x8c, 0x27, 0x64, 0x56, 0xcb, 0xc4, 0x21, 0x58, 0x40, 0x7d, 0xd0, 0x74, 0xee,
    ];
925 
    /// Protecting `SAMPLE_INITIAL_PAYLOAD` as a server Initial packet with packet
    /// number 1 must reproduce `SAMPLE_INITIAL` exactly.
    #[test]
    fn sample_server_initial() {
        fixture_init();
        let mut prot = CryptoDxState::test_default();

        // The spec uses PN=1, but our crypto refuses to skip packet numbers.
        // So burn an encryption:
        let burn = prot.encrypt(0, &[], &[]).expect("burn OK");
        // An empty plaintext produces only the AEAD expansion.
        assert_eq!(burn.len(), prot.expansion());

        let mut builder = PacketBuilder::long(
            Encoder::new(),
            PacketType::Initial,
            QuicVersion::default(),
            &ConnectionId::from(&[][..]),
            &ConnectionId::from(SERVER_CID),
        );
        builder.initial_token(&[]);
        builder.pn(1, 2);
        builder.encode(SAMPLE_INITIAL_PAYLOAD);
        let packet = builder.build(&mut prot).expect("build");
        assert_eq!(&packet[..], SAMPLE_INITIAL);
    }
949 
    /// Decoding `SAMPLE_INITIAL` with trailing bytes appended yields the Initial packet
    /// (empty destination connection ID, `SERVER_CID` as source, empty token) and hands
    /// back the trailing bytes untouched; the packet then decrypts to packet number 1.
    #[test]
    fn decrypt_initial() {
        const EXTRA: &[u8] = &[0xce; 33];

        fixture_init();
        let mut padded = SAMPLE_INITIAL.to_vec();
        padded.extend_from_slice(EXTRA);
        let (packet, remainder) = PublicPacket::decode(&padded, &cid_mgr()).unwrap();
        assert_eq!(packet.packet_type(), PacketType::Initial);
        assert_eq!(&packet.dcid()[..], &[] as &[u8]);
        assert_eq!(&packet.scid()[..], SERVER_CID);
        assert!(packet.token().is_empty());
        assert_eq!(remainder, EXTRA);

        let decrypted = packet
            .decrypt(&mut CryptoStates::test_default(), now())
            .unwrap();
        assert_eq!(decrypted.pn(), 1);
    }
969 
    /// A long header packet whose destination connection ID is longer than
    /// `MAX_CONNECTION_ID_LEN` must fail to decode.
    // NOTE(review): the trailing 0xff junk may be undecodable on its own, which could
    // mask a missing length check — confirm against `PublicPacket::decode`.
    #[test]
    fn disallow_long_dcid() {
        let mut enc = Encoder::new();
        enc.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC);
        enc.encode_uint(4, QuicVersion::default().as_u32());
        enc.encode_vec(1, &[0x00; MAX_CONNECTION_ID_LEN + 1]);
        enc.encode_vec(1, &[]);
        enc.encode(&[0xff; 40]); // junk

        assert!(PublicPacket::decode(&enc, &cid_mgr()).is_err());
    }
981 
    /// A long header packet whose source connection ID is longer than
    /// `MAX_CONNECTION_ID_LEN` must fail to decode.
    // NOTE(review): as with `disallow_long_dcid`, the 0xff junk may fail to decode by
    // itself; confirm that the connection ID length check is what is being exercised.
    #[test]
    fn disallow_long_scid() {
        let mut enc = Encoder::new();
        enc.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC);
        enc.encode_uint(4, QuicVersion::default().as_u32());
        enc.encode_vec(1, &[]);
        enc.encode_vec(1, &[0x00; MAX_CONNECTION_ID_LEN + 2]);
        enc.encode(&[0xff; 40]); // junk

        assert!(PublicPacket::decode(&enc, &cid_mgr()).is_err());
    }
993 
    /// Expected protected short header packet for `SERVER_CID` with packet number 0,
    /// carrying `SAMPLE_SHORT_PAYLOAD`.
    const SAMPLE_SHORT: &[u8] = &[
        0x40, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0xf4, 0xa8, 0x30, 0x39, 0xc4, 0x7d,
        0x99, 0xe3, 0x94, 0x1c, 0x9b, 0xb9, 0x7a, 0x30, 0x1d, 0xd5, 0x8f, 0xf3, 0xdd, 0xa9,
    ];
    /// Payload of `SAMPLE_SHORT`: three zero bytes.
    const SAMPLE_SHORT_PAYLOAD: &[u8] = &[0; 3];
999 
    /// Building a short header packet for `SERVER_CID` with packet number 0 and
    /// `SAMPLE_SHORT_PAYLOAD` must reproduce `SAMPLE_SHORT`.
    #[test]
    fn build_short() {
        fixture_init();
        let mut builder =
            PacketBuilder::short(Encoder::new(), true, &ConnectionId::from(SERVER_CID));
        builder.pn(0, 1);
        builder.encode(SAMPLE_SHORT_PAYLOAD); // Enough payload for sampling.
        let packet = builder
            .build(&mut CryptoDxState::test_default())
            .expect("build");
        assert_eq!(&packet[..], SAMPLE_SHORT);
    }
1012 
1013     #[test]
scramble_short()1014     fn scramble_short() {
1015         fixture_init();
1016         let mut firsts = Vec::new();
1017         for _ in 0..64 {
1018             let mut builder =
1019                 PacketBuilder::short(Encoder::new(), true, &ConnectionId::from(SERVER_CID));
1020             builder.scramble(true);
1021             builder.pn(0, 1);
1022             firsts.push(builder[0]);
1023         }
1024         let is_set = |bit| move |v| v & bit == bit;
1025         // There should be at least one value with the QUIC bit set:
1026         assert!(firsts.iter().any(is_set(PACKET_BIT_FIXED_QUIC)));
1027         // ... but not all:
1028         assert!(!firsts.iter().all(is_set(PACKET_BIT_FIXED_QUIC)));
1029         // There should be at least one value with the spin bit set:
1030         assert!(firsts.iter().any(is_set(PACKET_BIT_SPIN)));
1031         // ... but not all:
1032         assert!(!firsts.iter().all(is_set(PACKET_BIT_SPIN)));
1033     }
1034 
    /// `SAMPLE_SHORT` decodes as a short header packet that consumes the whole input and
    /// decrypts back to `SAMPLE_SHORT_PAYLOAD`.
    #[test]
    fn decode_short() {
        fixture_init();
        let (packet, remainder) = PublicPacket::decode(SAMPLE_SHORT, &cid_mgr()).unwrap();
        assert_eq!(packet.packet_type(), PacketType::Short);
        assert!(remainder.is_empty());
        let decrypted = packet
            .decrypt(&mut CryptoStates::test_default(), now())
            .unwrap();
        assert_eq!(&decrypted[..], SAMPLE_SHORT_PAYLOAD);
    }
1046 
    /// By telling the decoder that the connection ID is shorter than it really is, we get a decryption error.
    /// Decoding itself still succeeds; only the subsequent `decrypt` fails.
    #[test]
    fn decode_short_bad_cid() {
        fixture_init();
        let (packet, remainder) = PublicPacket::decode(
            SAMPLE_SHORT,
            &RandomConnectionIdGenerator::new(SERVER_CID.len() - 1),
        )
        .unwrap();
        assert_eq!(packet.packet_type(), PacketType::Short);
        assert!(remainder.is_empty());
        assert!(packet
            .decrypt(&mut CryptoStates::test_default(), now())
            .is_err());
    }
1062 
    /// Saying that the connection ID is longer than it really is causes the initial
    /// decode to fail.
    #[test]
    fn decode_short_long_cid() {
        assert!(PublicPacket::decode(
            SAMPLE_SHORT,
            &RandomConnectionIdGenerator::new(SERVER_CID.len() + 1)
        )
        .is_err());
    }
1072 
    /// Coalescing: a short header packet built on top of an encoder that already holds a
    /// Handshake packet is appended after it, leaving the first packet's bytes intact.
    #[test]
    fn build_two() {
        fixture_init();
        let mut prot = CryptoDxState::test_default();
        let mut builder = PacketBuilder::long(
            Encoder::new(),
            PacketType::Handshake,
            QuicVersion::default(),
            &ConnectionId::from(SERVER_CID),
            &ConnectionId::from(CLIENT_CID),
        );
        builder.pn(0, 1);
        builder.encode(&[0; 3]);
        let encoder = builder.build(&mut prot).expect("build");
        assert_eq!(encoder.len(), 45);
        let first = encoder.clone();

        let mut builder = PacketBuilder::short(encoder, false, &ConnectionId::from(SERVER_CID));
        builder.pn(1, 3);
        builder.encode(&[0]); // Minimal size (packet number is big enough).
        let encoder = builder.build(&mut prot).expect("build");
        assert_eq!(
            &first[..],
            &encoder[..first.len()],
            "the first packet should be a prefix"
        );
        // 45 bytes for the Handshake packet plus 29 for the short header packet.
        assert_eq!(encoder.len(), 45 + 29);
    }
1101 
    /// A Handshake packet with empty connection IDs, packet number 0 and payload
    /// `[1, 2, 3]` must be built as `EXPECTED`.
    #[test]
    fn build_long() {
        const EXPECTED: &[u8] = &[
            0xe4, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x40, 0x14, 0xfb, 0xa9, 0x32, 0x3a, 0xf8,
            0xbb, 0x18, 0x63, 0xc6, 0xbd, 0x78, 0x0e, 0xba, 0x0c, 0x98, 0x65, 0x58, 0xc9, 0x62,
            0x31,
        ];

        fixture_init();
        let mut builder = PacketBuilder::long(
            Encoder::new(),
            PacketType::Handshake,
            QuicVersion::default(),
            &ConnectionId::from(&[][..]),
            &ConnectionId::from(&[][..]),
        );
        builder.pn(0, 1);
        builder.encode(&[1, 2, 3]);
        let packet = builder.build(&mut CryptoDxState::test_default()).unwrap();
        assert_eq!(&packet[..], EXPECTED);
    }
1123 
1124     #[test]
scramble_long()1125     fn scramble_long() {
1126         fixture_init();
1127         let mut found_unset = false;
1128         let mut found_set = false;
1129         for _ in 1..64 {
1130             let mut builder = PacketBuilder::long(
1131                 Encoder::new(),
1132                 PacketType::Handshake,
1133                 QuicVersion::default(),
1134                 &ConnectionId::from(&[][..]),
1135                 &ConnectionId::from(&[][..]),
1136             );
1137             builder.pn(0, 1);
1138             builder.scramble(true);
1139             if (builder[0] & PACKET_BIT_FIXED_QUIC) == 0 {
1140                 found_unset = true;
1141             } else {
1142                 found_set = true;
1143             }
1144         }
1145         assert!(found_unset);
1146         assert!(found_set);
1147     }
1148 
    /// A builder reports remaining space at each setup step, and aborting it before any
    /// payload is added hands back an empty encoder.
    #[test]
    fn build_abort() {
        let mut builder = PacketBuilder::long(
            Encoder::new(),
            PacketType::Initial,
            QuicVersion::default(),
            &ConnectionId::from(&[][..]),
            &ConnectionId::from(SERVER_CID),
        );
        assert_ne!(builder.remaining(), 0);
        builder.initial_token(&[]);
        assert_ne!(builder.remaining(), 0);
        builder.pn(1, 2);
        assert_ne!(builder.remaining(), 0);
        let encoder = builder.abort();
        assert!(encoder.is_empty());
    }
1166 
    /// After a padded short header packet, a long header builder started on the same
    /// encoder reports no remaining space, and aborting it restores the encoder exactly.
    #[test]
    fn build_insufficient_space() {
        fixture_init();

        let mut builder = PacketBuilder::short(
            Encoder::with_capacity(100),
            true,
            &ConnectionId::from(SERVER_CID),
        );
        builder.pn(0, 1);
        // Pad, but not up to the full capacity. Leave enough space for the
        // AEAD expansion and some extra, but not for an entire long header.
        builder.set_limit(75);
        builder.enable_padding(true);
        assert!(builder.pad());
        let encoder = builder.build(&mut CryptoDxState::test_default()).unwrap();
        let encoder_copy = encoder.clone();

        let builder = PacketBuilder::long(
            encoder,
            PacketType::Initial,
            QuicVersion::default(),
            &ConnectionId::from(SERVER_CID),
            &ConnectionId::from(SERVER_CID),
        );
        assert_eq!(builder.remaining(), 0);
        assert_eq!(builder.abort(), encoder_copy);
    }
1195 
    // Sample Retry packets, one per supported version.  Each has an empty destination
    // connection ID, `SERVER_CID` as source, `RETRY_TOKEN` as the token, and an
    // integrity tag computed for `CLIENT_CID` as the original destination connection ID.

    /// Sample Retry packet for QUIC version 1.
    const SAMPLE_RETRY_V1: &[u8] = &[
        0xff, 0x00, 0x00, 0x00, 0x01, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
        0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x04, 0xa2, 0x65, 0xba, 0x2e, 0xff, 0x4d, 0x82, 0x90, 0x58,
        0xfb, 0x3f, 0x0f, 0x24, 0x96, 0xba,
    ];

    /// Sample Retry packet for draft-29.
    const SAMPLE_RETRY_29: &[u8] = &[
        0xff, 0xff, 0x00, 0x00, 0x1d, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
        0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xd1, 0x69, 0x26, 0xd8, 0x1f, 0x6f, 0x9c, 0xa2, 0x95, 0x3a,
        0x8a, 0xa4, 0x57, 0x5e, 0x1e, 0x49,
    ];

    /// Sample Retry packet for draft-30.
    const SAMPLE_RETRY_30: &[u8] = &[
        0xff, 0xff, 0x00, 0x00, 0x1e, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
        0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2d, 0x3e, 0x04, 0x5d, 0x6d, 0x39, 0x20, 0x67, 0x89, 0x94,
        0x37, 0x10, 0x8c, 0xe0, 0x0a, 0x61,
    ];

    /// Sample Retry packet for draft-31.
    const SAMPLE_RETRY_31: &[u8] = &[
        0xff, 0xff, 0x00, 0x00, 0x1f, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
        0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xc7, 0x0c, 0xe5, 0xde, 0x43, 0x0b, 0x4b, 0xdb, 0x7d, 0xf1,
        0xa3, 0x83, 0x3a, 0x75, 0xf9, 0x86,
    ];

    /// Sample Retry packet for draft-32.
    const SAMPLE_RETRY_32: &[u8] = &[
        0xff, 0xff, 0x00, 0x00, 0x20, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
        0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x59, 0x75, 0x65, 0x19, 0xdd, 0x6c, 0xc8, 0x5b, 0xd9, 0x0e,
        0x33, 0xa9, 0x34, 0xd2, 0xff, 0x85,
    ];

    /// Token carried by every sample Retry packet.
    const RETRY_TOKEN: &[u8] = b"token";
1227 
build_retry_single(quic_version: QuicVersion, sample_retry: &[u8])1228     fn build_retry_single(quic_version: QuicVersion, sample_retry: &[u8]) {
1229         fixture_init();
1230         let retry =
1231             PacketBuilder::retry(quic_version, &[], SERVER_CID, RETRY_TOKEN, CLIENT_CID).unwrap();
1232 
1233         let (packet, remainder) = PublicPacket::decode(&retry, &cid_mgr()).unwrap();
1234         assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID)));
1235         assert!(remainder.is_empty());
1236 
1237         // The builder adds randomness, which makes expectations hard.
1238         // So only do a full check when that randomness matches up.
1239         if retry[0] == sample_retry[0] {
1240             assert_eq!(&retry, &sample_retry);
1241         } else {
1242             // Otherwise, just check that the header is OK.
1243             assert_eq!(retry[0] & 0xf0, 0xf0);
1244             let header_range = 1..retry.len() - 16;
1245             assert_eq!(&retry[header_range.clone()], &sample_retry[header_range]);
1246         }
1247     }
1248 
    /// Builds and checks a Retry packet for QUIC version 1.
    #[test]
    fn build_retry_v1() {
        build_retry_single(QuicVersion::Version1, SAMPLE_RETRY_V1);
    }

    /// Builds and checks a Retry packet for draft-29.
    #[test]
    fn build_retry_29() {
        build_retry_single(QuicVersion::Draft29, SAMPLE_RETRY_29);
    }

    /// Builds and checks a Retry packet for draft-30.
    #[test]
    fn build_retry_30() {
        build_retry_single(QuicVersion::Draft30, SAMPLE_RETRY_30);
    }

    /// Builds and checks a Retry packet for draft-31.
    #[test]
    fn build_retry_31() {
        build_retry_single(QuicVersion::Draft31, SAMPLE_RETRY_31);
    }

    /// Builds and checks a Retry packet for draft-32.
    #[test]
    fn build_retry_32() {
        build_retry_single(QuicVersion::Draft32, SAMPLE_RETRY_32);
    }
1273 
    /// Repeats the Retry build tests so that the randomized first byte is likely to match
    /// each sample at least once, triggering the full byte-for-byte comparison.
    #[test]
    fn build_retry_multiple() {
        // Run the build_retry test a few times.
        // Across these 32 rounds, the odds are approximately 1 in 8 that the full
        // comparison never happens for a given version.
        for _ in 0..32 {
            build_retry_v1();
            build_retry_29();
            build_retry_30();
            build_retry_31();
            build_retry_32();
        }
    }
1287 
decode_retry(quic_version: QuicVersion, sample_retry: &[u8])1288     fn decode_retry(quic_version: QuicVersion, sample_retry: &[u8]) {
1289         fixture_init();
1290         let (packet, remainder) =
1291             PublicPacket::decode(sample_retry, &RandomConnectionIdGenerator::new(5)).unwrap();
1292         assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID)));
1293         assert_eq!(Some(quic_version), packet.quic_version);
1294         assert!(packet.dcid().is_empty());
1295         assert_eq!(&packet.scid()[..], SERVER_CID);
1296         assert_eq!(packet.token(), RETRY_TOKEN);
1297         assert!(remainder.is_empty());
1298     }
1299 
1300     #[test]
decode_retry_29()1301     fn decode_retry_29() {
1302         decode_retry(QuicVersion::Draft29, SAMPLE_RETRY_29);
1303     }
1304 
1305     #[test]
decode_retry_30()1306     fn decode_retry_30() {
1307         decode_retry(QuicVersion::Draft30, SAMPLE_RETRY_30);
1308     }
1309 
1310     #[test]
decode_retry_31()1311     fn decode_retry_31() {
1312         decode_retry(QuicVersion::Draft31, SAMPLE_RETRY_31);
1313     }
1314 
1315     #[test]
decode_retry_32()1316     fn decode_retry_32() {
1317         decode_retry(QuicVersion::Draft32, SAMPLE_RETRY_32);
1318     }
1319 
    /// Check some packets that are clearly not valid Retry packets.
    #[test]
    fn invalid_retry() {
        fixture_init();
        let cid_mgr = RandomConnectionIdGenerator::new(5);
        let odcid = ConnectionId::from(CLIENT_CID);

        // An empty input does not decode at all.
        assert!(PublicPacket::decode(&[], &cid_mgr).is_err());

        // The undamaged sample decodes and validates.
        let (packet, remainder) = PublicPacket::decode(SAMPLE_RETRY_29, &cid_mgr).unwrap();
        assert!(remainder.is_empty());
        assert!(packet.is_valid_retry(&odcid));

        // Flipping bits in the last byte still decodes, but fails validation.
        let mut damaged_retry = SAMPLE_RETRY_29.to_vec();
        let last = damaged_retry.len() - 1;
        damaged_retry[last] ^= 66;
        let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap();
        assert!(remainder.is_empty());
        assert!(!packet.is_valid_retry(&odcid));

        // Dropping the last byte still decodes, but fails validation.
        damaged_retry.truncate(last);
        let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap();
        assert!(remainder.is_empty());
        assert!(!packet.is_valid_retry(&odcid));

        // An invalid token should be rejected sooner.
        damaged_retry.truncate(last - 4);
        assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err());

        // NOTE(review): `Vec::truncate` does nothing when the new length is not shorter
        // than the current one, and the vector is already `last - 4` bytes long here, so
        // this re-checks exactly the same input as the assertion above.  A shorter length
        // was probably intended — confirm which case this should cover.
        damaged_retry.truncate(last - 1);
        assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err());
    }
1352 
    /// Version Negotiation packet with `SERVER_CID` as destination and `CLIENT_CID` as
    /// source, listing version 1, drafts 29 through 32 and a greased version.  The random
    /// bits are shown already masked out, as done in `build_vn`: the first byte is 0x80
    /// and the grease version is 0x0a0a0a0a.
    const SAMPLE_VN: &[u8] = &[
        0x80, 0x00, 0x00, 0x00, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0x08,
        0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08, 0x00, 0x00, 0x00, 0x01, 0xff, 0x00, 0x00,
        0x1d, 0xff, 0x00, 0x00, 0x1e, 0xff, 0x00, 0x00, 0x1f, 0xff, 0x00, 0x00, 0x20, 0x0a, 0x0a,
        0x0a, 0x0a,
    ];
1359 
    /// The Version Negotiation builder's output matches `SAMPLE_VN` once the random bits
    /// are cleared.  These are the low seven bits of the first byte and the high nibble
    /// of each byte of the trailing greased version.
    #[test]
    fn build_vn() {
        fixture_init();
        let mut vn = PacketBuilder::version_negotiation(SERVER_CID, CLIENT_CID);
        // Erase randomness from greasing...
        assert_eq!(vn.len(), SAMPLE_VN.len());
        vn[0] &= 0x80;
        for v in vn.iter_mut().skip(SAMPLE_VN.len() - 4) {
            *v &= 0x0f;
        }
        assert_eq!(&vn, &SAMPLE_VN);
    }
1372 
    /// `SAMPLE_VN` decodes with `SERVER_CID` as the destination and `CLIENT_CID` as the
    /// source connection ID, consuming the whole input.
    #[test]
    fn parse_vn() {
        let (packet, remainder) =
            PublicPacket::decode(SAMPLE_VN, &EmptyConnectionIdGenerator::default()).unwrap();
        assert!(remainder.is_empty());
        assert_eq!(&packet.dcid[..], SERVER_CID);
        assert!(packet.scid.is_some());
        assert_eq!(&packet.scid.unwrap()[..], CLIENT_CID);
    }
1382 
    /// A Version Negotiation packet can have a long connection ID.
    /// Both connection IDs here exceed `MAX_CONNECTION_ID_LEN` (the source one is 255
    /// bytes), yet the packet still decodes.
    #[test]
    fn parse_vn_big_cid() {
        const BIG_DCID: &[u8] = &[0x44; MAX_CONNECTION_ID_LEN + 1];
        const BIG_SCID: &[u8] = &[0xee; 255];

        let mut enc = Encoder::from(&[0xff, 0x00, 0x00, 0x00, 0x00][..]);
        enc.encode_vec(1, BIG_DCID);
        enc.encode_vec(1, BIG_SCID);
        enc.encode_uint(4, 0x1a2a_3a4a_u64);
        enc.encode_uint(4, QuicVersion::default().as_u32());
        enc.encode_uint(4, 0x5a6a_7a8a_u64);

        let (packet, remainder) =
            PublicPacket::decode(&enc, &EmptyConnectionIdGenerator::default()).unwrap();
        assert!(remainder.is_empty());
        assert_eq!(&packet.dcid[..], BIG_DCID);
        assert!(packet.scid.is_some());
        assert_eq!(&packet.scid.unwrap()[..], BIG_SCID);
    }
1403 
    /// `PublicPacket::decode_pn(expected, truncated, pn_len)` rebuilds a full packet number
    /// from a truncated `pn_len`-byte value.  `expected` is the next expected packet
    /// number, which is what `decrypt_header` passes in.
    #[test]
    fn decode_pn() {
        // When the expected value is low, the value doesn't go negative.
        assert_eq!(PublicPacket::decode_pn(0, 0, 1), 0);
        assert_eq!(PublicPacket::decode_pn(0, 0xff, 1), 0xff);
        assert_eq!(PublicPacket::decode_pn(10, 0, 1), 0);
        assert_eq!(PublicPacket::decode_pn(0x7f, 0, 1), 0);
        assert_eq!(PublicPacket::decode_pn(0x80, 0, 1), 0x100);
        assert_eq!(PublicPacket::decode_pn(0x80, 2, 1), 2);
        assert_eq!(PublicPacket::decode_pn(0x80, 0xff, 1), 0xff);
        assert_eq!(PublicPacket::decode_pn(0x7ff, 0xfe, 1), 0x7fe);

        // Packet numbers above 2^62-1 are invalid by spec, and an implementation is
        // expected to guard against that overflow.  `decode_pn` does not check for it,
        // because reaching such packet numbers is basically impossible in practice; this
        // case just pins the unchecked result.
        assert_eq!(
            PublicPacket::decode_pn(0x3fff_ffff_ffff_ffff, 2, 4),
            0x4000_0000_0000_0002
        );
    }
1424 
    /// A short header packet protected with the keys from `CryptoStates::test_chacha`
    /// (presumably ChaCha20-Poly1305, as the name suggests) decrypts to packet number
    /// 654360564 and a single 0x01 byte.
    #[test]
    fn chacha20_sample() {
        const PACKET: &[u8] = &[
            0x4c, 0xfe, 0x41, 0x89, 0x65, 0x5e, 0x5c, 0xd5, 0x5c, 0x41, 0xf6, 0x90, 0x80, 0x57,
            0x5d, 0x79, 0x99, 0xc2, 0x5a, 0x5b, 0xfb,
        ];
        fixture_init();
        let (packet, slice) =
            PublicPacket::decode(PACKET, &EmptyConnectionIdGenerator::default()).unwrap();
        assert!(slice.is_empty());
        let decrypted = packet
            .decrypt(&mut CryptoStates::test_chacha(), now())
            .unwrap();
        assert_eq!(decrypted.packet_type(), PacketType::Short);
        assert_eq!(decrypted.pn(), 654_360_564);
        assert_eq!(&decrypted[..], &[0x01]);
    }
1442 }
1443