1 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4 // option. This file may not be copied, modified, or distributed
5 // except according to those terms.
6 
7 // Transport parameters. See -transport section 7.3.
8 
9 use crate::cid::{
10     ConnectionId, ConnectionIdEntry, CONNECTION_ID_SEQNO_PREFERRED, MAX_CONNECTION_ID_LEN,
11 };
12 use crate::{Error, Res};
13 
14 use neqo_common::{hex, qdebug, qinfo, qtrace, Decoder, Encoder};
15 use neqo_crypto::constants::{TLS_HS_CLIENT_HELLO, TLS_HS_ENCRYPTED_EXTENSIONS};
16 use neqo_crypto::ext::{ExtensionHandler, ExtensionHandlerResult, ExtensionWriterResult};
17 use neqo_crypto::{HandshakeMessage, ZeroRttCheckResult, ZeroRttChecker};
18 
19 use std::cell::RefCell;
20 use std::collections::HashMap;
21 use std::convert::TryFrom;
22 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
23 use std::rc::Rc;
24 
/// Identifier of a transport parameter.
pub type TransportParameterId = u64;

/// Expands each `NAME = value` entry into a public `TransportParameterId` constant.
macro_rules! tpids {
    ( $( $name:ident = $value:expr ),+ $(,)? ) => {
        $(
            pub const $name: TransportParameterId = $value as TransportParameterId;
        )+
    };
}

tpids! {
    ORIGINAL_DESTINATION_CONNECTION_ID = 0x00,
    IDLE_TIMEOUT = 0x01,
    STATELESS_RESET_TOKEN = 0x02,
    MAX_UDP_PAYLOAD_SIZE = 0x03,
    INITIAL_MAX_DATA = 0x04,
    INITIAL_MAX_STREAM_DATA_BIDI_LOCAL = 0x05,
    INITIAL_MAX_STREAM_DATA_BIDI_REMOTE = 0x06,
    INITIAL_MAX_STREAM_DATA_UNI = 0x07,
    INITIAL_MAX_STREAMS_BIDI = 0x08,
    INITIAL_MAX_STREAMS_UNI = 0x09,
    ACK_DELAY_EXPONENT = 0x0a,
    MAX_ACK_DELAY = 0x0b,
    DISABLE_MIGRATION = 0x0c,
    PREFERRED_ADDRESS = 0x0d,
    ACTIVE_CONNECTION_ID_LIMIT = 0x0e,
    INITIAL_SOURCE_CONNECTION_ID = 0x0f,
    RETRY_SOURCE_CONNECTION_ID = 0x10,
    GREASE_QUIC_BIT = 0x2ab2,
    MIN_ACK_DELAY = 0xff02_de1a,
}
52 
/// A validated pair of optional socket addresses, one per IP family.
#[derive(Clone, Debug)]
pub struct PreferredAddress {
    v4: Option<SocketAddr>,
    v6: Option<SocketAddr>,
}

impl PreferredAddress {
    /// Build a preferred address configuration from an optional IPv4 and an
    /// optional IPv6 socket address.
    ///
    /// # Panics
    /// When both arguments are `None`, when an address is of the wrong family
    /// for its slot, when an address is the unspecified address, or when a
    /// port is zero.
    #[must_use]
    pub fn new(v4: Option<SocketAddr>, v6: Option<SocketAddr>) -> Self {
        assert!(v4.is_some() || v6.is_some());
        if let Some(a) = v4 {
            match a.ip() {
                IpAddr::V4(addr) => assert!(!addr.is_unspecified()),
                IpAddr::V6(_) => panic!("invalid address type for v4 address"),
            }
            assert_ne!(a.port(), 0);
        }
        if let Some(a) = v6 {
            match a.ip() {
                IpAddr::V6(addr) => assert!(!addr.is_unspecified()),
                IpAddr::V4(_) => panic!("invalid address type for v6 address"),
            }
            assert_ne!(a.port(), 0);
        }
        Self { v4, v6 }
    }

    /// The IPv4 address and port, if one was configured.
    #[must_use]
    pub fn ipv4(&self) -> Option<SocketAddr> {
        self.v4
    }

    /// The IPv6 address and port, if one was configured.
    #[must_use]
    pub fn ipv6(&self) -> Option<SocketAddr> {
        self.v6
    }
}
95 
96 #[derive(Clone, Debug, PartialEq)]
97 pub enum TransportParameter {
98     Bytes(Vec<u8>),
99     Integer(u64),
100     Empty,
101     PreferredAddress {
102         v4: Option<SocketAddr>,
103         v6: Option<SocketAddr>,
104         cid: ConnectionId,
105         srt: [u8; 16],
106     },
107 }
108 
109 impl TransportParameter {
encode(&self, enc: &mut Encoder, tp: TransportParameterId)110     fn encode(&self, enc: &mut Encoder, tp: TransportParameterId) {
111         qdebug!("TP encoded; type 0x{:02x} val {:?}", tp, self);
112         enc.encode_varint(tp);
113         match self {
114             Self::Bytes(a) => {
115                 enc.encode_vvec(a);
116             }
117             Self::Integer(a) => {
118                 enc.encode_vvec_with(|enc_inner| {
119                     enc_inner.encode_varint(*a);
120                 });
121             }
122             Self::Empty => {
123                 enc.encode_varint(0_u64);
124             }
125             Self::PreferredAddress { v4, v6, cid, srt } => {
126                 enc.encode_vvec_with(|enc_inner| {
127                     if let Some(v4) = v4 {
128                         debug_assert!(v4.is_ipv4());
129                         if let IpAddr::V4(a) = v4.ip() {
130                             enc_inner.encode(&a.octets()[..]);
131                         } else {
132                             unreachable!();
133                         }
134                         enc_inner.encode_uint(2, v4.port());
135                     } else {
136                         enc_inner.encode(&[0; 6]);
137                     }
138                     if let Some(v6) = v6 {
139                         debug_assert!(v6.is_ipv6());
140                         if let IpAddr::V6(a) = v6.ip() {
141                             enc_inner.encode(&a.octets()[..]);
142                         } else {
143                             unreachable!();
144                         }
145                         enc_inner.encode_uint(2, v6.port());
146                     } else {
147                         enc_inner.encode(&[0; 18]);
148                     }
149                     enc_inner.encode_vec(1, &cid[..]);
150                     enc_inner.encode(&srt[..]);
151                 });
152             }
153         };
154     }
155 
decode_preferred_address(d: &mut Decoder) -> Res<Self>156     fn decode_preferred_address(d: &mut Decoder) -> Res<Self> {
157         // IPv4 address (maybe)
158         let v4ip =
159             Ipv4Addr::from(<[u8; 4]>::try_from(d.decode(4).ok_or(Error::NoMoreData)?).unwrap());
160         let v4port = u16::try_from(d.decode_uint(2).ok_or(Error::NoMoreData)?).unwrap();
161         // Can't have non-zero IP and zero port, or vice versa.
162         if v4ip.is_unspecified() ^ (v4port == 0) {
163             return Err(Error::TransportParameterError);
164         }
165         let v4 = if v4port == 0 {
166             None
167         } else {
168             Some(SocketAddr::new(IpAddr::V4(v4ip), v4port))
169         };
170 
171         // IPv6 address (mostly the same as v4)
172         let v6ip =
173             Ipv6Addr::from(<[u8; 16]>::try_from(d.decode(16).ok_or(Error::NoMoreData)?).unwrap());
174         let v6port = u16::try_from(d.decode_uint(2).ok_or(Error::NoMoreData)?).unwrap();
175         if v6ip.is_unspecified() ^ (v6port == 0) {
176             return Err(Error::TransportParameterError);
177         }
178         let v6 = if v6port == 0 {
179             None
180         } else {
181             Some(SocketAddr::new(IpAddr::V6(v6ip), v6port))
182         };
183         // Need either v4 or v6 to be present.
184         if v4.is_none() && v6.is_none() {
185             return Err(Error::TransportParameterError);
186         }
187 
188         // Connection ID (non-zero length)
189         let cid = ConnectionId::from(d.decode_vec(1).ok_or(Error::NoMoreData)?);
190         if cid.len() == 0 || cid.len() > MAX_CONNECTION_ID_LEN {
191             return Err(Error::TransportParameterError);
192         }
193 
194         // Stateless reset token
195         let srtbuf = d.decode(16).ok_or(Error::NoMoreData)?;
196         let srt = <[u8; 16]>::try_from(srtbuf).unwrap();
197 
198         Ok(Self::PreferredAddress { v4, v6, cid, srt })
199     }
200 
decode(dec: &mut Decoder) -> Res<Option<(TransportParameterId, Self)>>201     fn decode(dec: &mut Decoder) -> Res<Option<(TransportParameterId, Self)>> {
202         let tp = dec.decode_varint().ok_or(Error::NoMoreData)?;
203         let content = dec.decode_vvec().ok_or(Error::NoMoreData)?;
204         qtrace!("TP {:x} length {:x}", tp, content.len());
205         let mut d = Decoder::from(content);
206         let value = match tp {
207             ORIGINAL_DESTINATION_CONNECTION_ID
208             | INITIAL_SOURCE_CONNECTION_ID
209             | RETRY_SOURCE_CONNECTION_ID => Self::Bytes(d.decode_remainder().to_vec()),
210             STATELESS_RESET_TOKEN => {
211                 if d.remaining() != 16 {
212                     return Err(Error::TransportParameterError);
213                 }
214                 Self::Bytes(d.decode_remainder().to_vec())
215             }
216             IDLE_TIMEOUT
217             | INITIAL_MAX_DATA
218             | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL
219             | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE
220             | INITIAL_MAX_STREAM_DATA_UNI
221             | MAX_ACK_DELAY => match d.decode_varint() {
222                 Some(v) => Self::Integer(v),
223                 None => return Err(Error::TransportParameterError),
224             },
225 
226             INITIAL_MAX_STREAMS_BIDI | INITIAL_MAX_STREAMS_UNI => match d.decode_varint() {
227                 Some(v) if v <= (1 << 60) => Self::Integer(v),
228                 _ => return Err(Error::StreamLimitError),
229             },
230 
231             MAX_UDP_PAYLOAD_SIZE => match d.decode_varint() {
232                 Some(v) if v >= 1200 => Self::Integer(v),
233                 _ => return Err(Error::TransportParameterError),
234             },
235 
236             ACK_DELAY_EXPONENT => match d.decode_varint() {
237                 Some(v) if v <= 20 => Self::Integer(v),
238                 _ => return Err(Error::TransportParameterError),
239             },
240             ACTIVE_CONNECTION_ID_LIMIT => match d.decode_varint() {
241                 Some(v) if v >= 2 => Self::Integer(v),
242                 _ => return Err(Error::TransportParameterError),
243             },
244 
245             DISABLE_MIGRATION | GREASE_QUIC_BIT => Self::Empty,
246 
247             PREFERRED_ADDRESS => Self::decode_preferred_address(&mut d)?,
248 
249             MIN_ACK_DELAY => match d.decode_varint() {
250                 Some(v) if v < (1 << 24) => Self::Integer(v),
251                 _ => return Err(Error::TransportParameterError),
252             },
253 
254             // Skip.
255             _ => return Ok(None),
256         };
257         if d.remaining() > 0 {
258             return Err(Error::TooMuchData);
259         }
260         qdebug!("TP decoded; type 0x{:02x} val {:?}", tp, value);
261         Ok(Some((tp, value)))
262     }
263 }
264 
265 #[derive(Clone, Debug, Default, PartialEq)]
266 pub struct TransportParameters {
267     params: HashMap<TransportParameterId, TransportParameter>,
268 }
269 
270 impl TransportParameters {
271     /// Set a value.
set(&mut self, k: TransportParameterId, v: TransportParameter)272     pub fn set(&mut self, k: TransportParameterId, v: TransportParameter) {
273         self.params.insert(k, v);
274     }
275 
276     /// Clear a key.
remove(&mut self, k: TransportParameterId)277     pub fn remove(&mut self, k: TransportParameterId) {
278         self.params.remove(&k);
279     }
280 
281     /// Decode is a static function that parses transport parameters
282     /// using the provided decoder.
decode(d: &mut Decoder) -> Res<Self>283     pub(crate) fn decode(d: &mut Decoder) -> Res<Self> {
284         let mut tps = Self::default();
285         qtrace!("Parsed fixed TP header");
286 
287         while d.remaining() > 0 {
288             match TransportParameter::decode(d) {
289                 Ok(Some((tipe, tp))) => {
290                     tps.set(tipe, tp);
291                 }
292                 Ok(None) => {}
293                 Err(e) => return Err(e),
294             }
295         }
296         Ok(tps)
297     }
298 
encode(&self, enc: &mut Encoder)299     pub(crate) fn encode(&self, enc: &mut Encoder) {
300         for (tipe, tp) in &self.params {
301             tp.encode(enc, *tipe);
302         }
303     }
304 
305     // Get an integer type or a default.
get_integer(&self, tp: TransportParameterId) -> u64306     pub fn get_integer(&self, tp: TransportParameterId) -> u64 {
307         let default = match tp {
308             IDLE_TIMEOUT
309             | INITIAL_MAX_DATA
310             | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL
311             | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE
312             | INITIAL_MAX_STREAM_DATA_UNI
313             | INITIAL_MAX_STREAMS_BIDI
314             | INITIAL_MAX_STREAMS_UNI => 0,
315             MAX_UDP_PAYLOAD_SIZE => 65527,
316             ACK_DELAY_EXPONENT => 3,
317             MAX_ACK_DELAY => 25,
318             ACTIVE_CONNECTION_ID_LIMIT => 2,
319             MIN_ACK_DELAY => 0,
320             _ => panic!("Transport parameter not known or not an Integer"),
321         };
322         match self.params.get(&tp) {
323             None => default,
324             Some(TransportParameter::Integer(x)) => *x,
325             _ => panic!("Internal error"),
326         }
327     }
328 
329     // Set an integer type or a default.
set_integer(&mut self, tp: TransportParameterId, value: u64)330     pub fn set_integer(&mut self, tp: TransportParameterId, value: u64) {
331         match tp {
332             IDLE_TIMEOUT
333             | INITIAL_MAX_DATA
334             | INITIAL_MAX_STREAM_DATA_BIDI_LOCAL
335             | INITIAL_MAX_STREAM_DATA_BIDI_REMOTE
336             | INITIAL_MAX_STREAM_DATA_UNI
337             | INITIAL_MAX_STREAMS_BIDI
338             | INITIAL_MAX_STREAMS_UNI
339             | MAX_UDP_PAYLOAD_SIZE
340             | ACK_DELAY_EXPONENT
341             | MAX_ACK_DELAY
342             | ACTIVE_CONNECTION_ID_LIMIT
343             | MIN_ACK_DELAY => {
344                 self.set(tp, TransportParameter::Integer(value));
345             }
346             _ => panic!("Transport parameter not known"),
347         }
348     }
349 
get_bytes(&self, tp: TransportParameterId) -> Option<&[u8]>350     pub fn get_bytes(&self, tp: TransportParameterId) -> Option<&[u8]> {
351         match tp {
352             ORIGINAL_DESTINATION_CONNECTION_ID
353             | INITIAL_SOURCE_CONNECTION_ID
354             | RETRY_SOURCE_CONNECTION_ID
355             | STATELESS_RESET_TOKEN => {}
356             _ => panic!("Transport parameter not known or not type bytes"),
357         }
358 
359         match self.params.get(&tp) {
360             None => None,
361             Some(TransportParameter::Bytes(x)) => Some(&x),
362             _ => panic!("Internal error"),
363         }
364     }
365 
set_bytes(&mut self, tp: TransportParameterId, value: Vec<u8>)366     pub fn set_bytes(&mut self, tp: TransportParameterId, value: Vec<u8>) {
367         match tp {
368             ORIGINAL_DESTINATION_CONNECTION_ID
369             | INITIAL_SOURCE_CONNECTION_ID
370             | RETRY_SOURCE_CONNECTION_ID
371             | STATELESS_RESET_TOKEN => {
372                 self.set(tp, TransportParameter::Bytes(value));
373             }
374             _ => panic!("Transport parameter not known or not type bytes"),
375         }
376     }
377 
set_empty(&mut self, tp: TransportParameterId)378     pub fn set_empty(&mut self, tp: TransportParameterId) {
379         match tp {
380             DISABLE_MIGRATION | GREASE_QUIC_BIT => {
381                 self.set(tp, TransportParameter::Empty);
382             }
383             _ => panic!("Transport parameter not known or not type empty"),
384         }
385     }
386 
get_empty(&self, tipe: TransportParameterId) -> bool387     pub fn get_empty(&self, tipe: TransportParameterId) -> bool {
388         match self.params.get(&tipe) {
389             None => false,
390             Some(TransportParameter::Empty) => true,
391             _ => panic!("Internal error"),
392         }
393     }
394 
395     /// Return true if the remembered transport parameters are OK for 0-RTT.
396     /// Generally this means that any value that is currently in effect is greater than
397     /// or equal to the promised value.
ok_for_0rtt(&self, remembered: &Self) -> bool398     pub(crate) fn ok_for_0rtt(&self, remembered: &Self) -> bool {
399         for (k, v_rem) in &remembered.params {
400             // Skip checks for these, which don't affect 0-RTT.
401             if matches!(
402                 *k,
403                 ORIGINAL_DESTINATION_CONNECTION_ID
404                     | INITIAL_SOURCE_CONNECTION_ID
405                     | RETRY_SOURCE_CONNECTION_ID
406                     | STATELESS_RESET_TOKEN
407                     | IDLE_TIMEOUT
408                     | ACK_DELAY_EXPONENT
409                     | MAX_ACK_DELAY
410                     | ACTIVE_CONNECTION_ID_LIMIT
411                     | PREFERRED_ADDRESS
412             ) {
413                 continue;
414             }
415             if let Some(v_self) = self.params.get(k) {
416                 match (v_self, v_rem) {
417                     (TransportParameter::Integer(i_self), TransportParameter::Integer(i_rem)) => {
418                         if *k == MIN_ACK_DELAY {
419                             // MIN_ACK_DELAY is backwards:
420                             // it can only be reduced safely.
421                             if *i_self > *i_rem {
422                                 return false;
423                             }
424                         } else if *i_self < *i_rem {
425                             return false;
426                         }
427                     }
428                     (TransportParameter::Empty, TransportParameter::Empty) => {}
429                     _ => return false,
430                 }
431             } else {
432                 return false;
433             }
434         }
435         true
436     }
437 
438     /// Get the preferred address in a usable form.
439     #[must_use]
get_preferred_address(&self) -> Option<(PreferredAddress, ConnectionIdEntry<[u8; 16]>)>440     pub fn get_preferred_address(&self) -> Option<(PreferredAddress, ConnectionIdEntry<[u8; 16]>)> {
441         if let Some(TransportParameter::PreferredAddress { v4, v6, cid, srt }) =
442             self.params.get(&PREFERRED_ADDRESS)
443         {
444             Some((
445                 PreferredAddress::new(*v4, *v6),
446                 ConnectionIdEntry::new(CONNECTION_ID_SEQNO_PREFERRED, cid.clone(), *srt),
447             ))
448         } else {
449             None
450         }
451     }
452 
453     #[must_use]
has_value(&self, tp: TransportParameterId) -> bool454     pub fn has_value(&self, tp: TransportParameterId) -> bool {
455         self.params.contains_key(&tp)
456     }
457 }
458 
459 #[derive(Default, Debug)]
460 pub struct TransportParametersHandler {
461     pub(crate) local: TransportParameters,
462     pub(crate) remote: Option<TransportParameters>,
463     pub(crate) remote_0rtt: Option<TransportParameters>,
464 }
465 
466 impl TransportParametersHandler {
remote(&self) -> &TransportParameters467     pub fn remote(&self) -> &TransportParameters {
468         match (self.remote.as_ref(), self.remote_0rtt.as_ref()) {
469             (Some(tp), _) | (_, Some(tp)) => tp,
470             _ => panic!("no transport parameters from peer"),
471         }
472     }
473 }
474 
475 impl ExtensionHandler for TransportParametersHandler {
write(&mut self, msg: HandshakeMessage, d: &mut [u8]) -> ExtensionWriterResult476     fn write(&mut self, msg: HandshakeMessage, d: &mut [u8]) -> ExtensionWriterResult {
477         if !matches!(msg, TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS) {
478             return ExtensionWriterResult::Skip;
479         }
480 
481         qdebug!("Writing transport parameters, msg={:?}", msg);
482 
483         // TODO(ekr@rtfm.com): Modify to avoid a copy.
484         let mut enc = Encoder::default();
485         self.local.encode(&mut enc);
486         assert!(enc.len() <= d.len());
487         d[..enc.len()].copy_from_slice(&enc);
488         ExtensionWriterResult::Write(enc.len())
489     }
490 
handle(&mut self, msg: HandshakeMessage, d: &[u8]) -> ExtensionHandlerResult491     fn handle(&mut self, msg: HandshakeMessage, d: &[u8]) -> ExtensionHandlerResult {
492         qtrace!(
493             "Handling transport parameters, msg={:?} value={}",
494             msg,
495             hex(d),
496         );
497 
498         if !matches!(msg, TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS) {
499             return ExtensionHandlerResult::Alert(110); // unsupported_extension
500         }
501 
502         let mut dec = Decoder::from(d);
503         match TransportParameters::decode(&mut dec) {
504             Ok(tp) => {
505                 self.remote = Some(tp);
506                 ExtensionHandlerResult::Ok
507             }
508             _ => ExtensionHandlerResult::Alert(47), // illegal_parameter
509         }
510     }
511 }
512 
513 #[derive(Debug)]
514 pub(crate) struct TpZeroRttChecker<T> {
515     handler: Rc<RefCell<TransportParametersHandler>>,
516     app_checker: T,
517 }
518 
519 impl<T> TpZeroRttChecker<T>
520 where
521     T: ZeroRttChecker + 'static,
522 {
wrap( handler: Rc<RefCell<TransportParametersHandler>>, app_checker: T, ) -> Box<dyn ZeroRttChecker>523     pub fn wrap(
524         handler: Rc<RefCell<TransportParametersHandler>>,
525         app_checker: T,
526     ) -> Box<dyn ZeroRttChecker> {
527         Box::new(Self {
528             handler,
529             app_checker,
530         })
531     }
532 }
533 
534 impl<T> ZeroRttChecker for TpZeroRttChecker<T>
535 where
536     T: ZeroRttChecker,
537 {
check(&self, token: &[u8]) -> ZeroRttCheckResult538     fn check(&self, token: &[u8]) -> ZeroRttCheckResult {
539         // Reject 0-RTT if there is no token.
540         if token.is_empty() {
541             qdebug!("0-RTT: no token, no 0-RTT");
542             return ZeroRttCheckResult::Reject;
543         }
544         let mut dec = Decoder::from(token);
545         let tpslice = if let Some(v) = dec.decode_vvec() {
546             v
547         } else {
548             qinfo!("0-RTT: token code error");
549             return ZeroRttCheckResult::Fail;
550         };
551         let mut dec_tp = Decoder::from(tpslice);
552         let remembered = if let Ok(v) = TransportParameters::decode(&mut dec_tp) {
553             v
554         } else {
555             qinfo!("0-RTT: transport parameter decode error");
556             return ZeroRttCheckResult::Fail;
557         };
558         if self.handler.borrow().local.ok_for_0rtt(&remembered) {
559             qinfo!("0-RTT: transport parameters OK, passing to application checker");
560             self.app_checker.check(dec.decode_remainder())
561         } else {
562             qinfo!("0-RTT: transport parameters bad, rejecting");
563             ZeroRttCheckResult::Reject
564         }
565     }
566 }
567 
568 // TODO(ekr@rtfm.com): Need to write more TP unit tests.
569 #[cfg(test)]
570 #[allow(unused_variables)]
571 mod tests {
572     use super::*;
573     use std::mem;
574 
575     #[test]
basic_tps()576     fn basic_tps() {
577         const RESET_TOKEN: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8];
578         let mut tps = TransportParameters::default();
579         tps.set(
580             STATELESS_RESET_TOKEN,
581             TransportParameter::Bytes(RESET_TOKEN.to_vec()),
582         );
583         tps.params
584             .insert(INITIAL_MAX_STREAMS_BIDI, TransportParameter::Integer(10));
585 
586         let mut enc = Encoder::default();
587         tps.encode(&mut enc);
588 
589         let tps2 = TransportParameters::decode(&mut enc.as_decoder()).expect("Couldn't decode");
590         assert_eq!(tps, tps2);
591 
592         println!("TPS = {:?}", tps);
593         assert_eq!(tps2.get_integer(IDLE_TIMEOUT), 0); // Default
594         assert_eq!(tps2.get_integer(MAX_ACK_DELAY), 25); // Default
595         assert_eq!(tps2.get_integer(ACTIVE_CONNECTION_ID_LIMIT), 2); // Default
596         assert_eq!(tps2.get_integer(INITIAL_MAX_STREAMS_BIDI), 10); // Sent
597         assert_eq!(tps2.get_bytes(STATELESS_RESET_TOKEN), Some(RESET_TOKEN));
598         assert_eq!(tps2.get_bytes(ORIGINAL_DESTINATION_CONNECTION_ID), None);
599         assert_eq!(tps2.get_bytes(INITIAL_SOURCE_CONNECTION_ID), None);
600         assert_eq!(tps2.get_bytes(RETRY_SOURCE_CONNECTION_ID), None);
601         assert!(!tps2.has_value(ORIGINAL_DESTINATION_CONNECTION_ID));
602         assert!(!tps2.has_value(INITIAL_SOURCE_CONNECTION_ID));
603         assert!(!tps2.has_value(RETRY_SOURCE_CONNECTION_ID));
604         assert!(tps2.has_value(STATELESS_RESET_TOKEN));
605 
606         let mut enc = Encoder::default();
607         tps.encode(&mut enc);
608 
609         let tps2 = TransportParameters::decode(&mut enc.as_decoder()).expect("Couldn't decode");
610     }
611 
make_spa() -> TransportParameter612     fn make_spa() -> TransportParameter {
613         TransportParameter::PreferredAddress {
614             v4: Some(SocketAddr::new(
615                 IpAddr::V4(Ipv4Addr::from(0xc000_0201)),
616                 443,
617             )),
618             v6: Some(SocketAddr::new(
619                 IpAddr::V6(Ipv6Addr::from(0xfe80_0000_0000_0000_0000_0000_0000_0001)),
620                 443,
621             )),
622             cid: ConnectionId::from(&[1, 2, 3, 4, 5]),
623             srt: [3; 16],
624         }
625     }
626 
627     #[test]
preferred_address_encode_decode()628     fn preferred_address_encode_decode() {
629         const ENCODED: &[u8] = &[
630             0x0d, 0x2e, 0xc0, 0x00, 0x02, 0x01, 0x01, 0xbb, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
631             0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0xbb, 0x05, 0x01,
632             0x02, 0x03, 0x04, 0x05, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
633             0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
634         ];
635         let spa = make_spa();
636         let mut enc = Encoder::new();
637         spa.encode(&mut enc, PREFERRED_ADDRESS);
638         assert_eq!(&enc[..], ENCODED);
639 
640         let mut dec = enc.as_decoder();
641         let (id, decoded) = TransportParameter::decode(&mut dec).unwrap().unwrap();
642         assert_eq!(id, PREFERRED_ADDRESS);
643         assert_eq!(decoded, spa);
644     }
645 
mutate_spa<F>(wrecker: F) -> TransportParameter where F: FnOnce(&mut Option<SocketAddr>, &mut Option<SocketAddr>, &mut ConnectionId),646     fn mutate_spa<F>(wrecker: F) -> TransportParameter
647     where
648         F: FnOnce(&mut Option<SocketAddr>, &mut Option<SocketAddr>, &mut ConnectionId),
649     {
650         let mut spa = make_spa();
651         if let TransportParameter::PreferredAddress {
652             ref mut v4,
653             ref mut v6,
654             ref mut cid,
655             ..
656         } = &mut spa
657         {
658             wrecker(v4, v6, cid);
659         } else {
660             unreachable!();
661         }
662         spa
663     }
664 
665     /// This takes a `TransportParameter::PreferredAddress` that has been mutilated.
666     /// It then encodes it, working from the knowledge that the `encode` function
667     /// doesn't care about validity, and decodes it.  The result should be failure.
assert_invalid_spa(spa: TransportParameter)668     fn assert_invalid_spa(spa: TransportParameter) {
669         let mut enc = Encoder::new();
670         spa.encode(&mut enc, PREFERRED_ADDRESS);
671         assert_eq!(
672             TransportParameter::decode(&mut enc.as_decoder()).unwrap_err(),
673             Error::TransportParameterError
674         );
675     }
676 
677     /// This is for those rare mutations that are acceptable.
assert_valid_spa(spa: TransportParameter)678     fn assert_valid_spa(spa: TransportParameter) {
679         let mut enc = Encoder::new();
680         spa.encode(&mut enc, PREFERRED_ADDRESS);
681         let mut dec = enc.as_decoder();
682         let (id, decoded) = TransportParameter::decode(&mut dec).unwrap().unwrap();
683         assert_eq!(id, PREFERRED_ADDRESS);
684         assert_eq!(decoded, spa);
685     }
686 
687     #[test]
preferred_address_zero_address()688     fn preferred_address_zero_address() {
689         // Either port being zero is bad.
690         assert_invalid_spa(mutate_spa(|v4, _, _| {
691             v4.as_mut().unwrap().set_port(0);
692         }));
693         assert_invalid_spa(mutate_spa(|_, v6, _| {
694             v6.as_mut().unwrap().set_port(0);
695         }));
696         // Either IP being zero is bad.
697         assert_invalid_spa(mutate_spa(|v4, _, _| {
698             v4.as_mut().unwrap().set_ip(IpAddr::V4(Ipv4Addr::from(0)));
699         }));
700         assert_invalid_spa(mutate_spa(|_, v6, _| {
701             v6.as_mut().unwrap().set_ip(IpAddr::V6(Ipv6Addr::from(0)));
702         }));
703         // Either address being absent is OK.
704         assert_valid_spa(mutate_spa(|v4, _, _| {
705             *v4 = None;
706         }));
707         assert_valid_spa(mutate_spa(|_, v6, _| {
708             *v6 = None;
709         }));
710         // Both addresses being absent is bad.
711         assert_invalid_spa(mutate_spa(|v4, v6, _| {
712             *v4 = None;
713             *v6 = None;
714         }));
715     }
716 
717     #[test]
preferred_address_bad_cid()718     fn preferred_address_bad_cid() {
719         assert_invalid_spa(mutate_spa(|_, _, cid| {
720             *cid = ConnectionId::from(&[]);
721         }));
722         assert_invalid_spa(mutate_spa(|_, _, cid| {
723             *cid = ConnectionId::from(&[0x0c; 21]);
724         }));
725     }
726 
727     #[test]
preferred_address_truncated()728     fn preferred_address_truncated() {
729         let spa = make_spa();
730         let mut enc = Encoder::new();
731         spa.encode(&mut enc, PREFERRED_ADDRESS);
732         let mut dec = Decoder::from(&enc[..enc.len() - 1]);
733         assert_eq!(
734             TransportParameter::decode(&mut dec).unwrap_err(),
735             Error::NoMoreData
736         );
737     }
738 
739     #[test]
740     #[should_panic]
preferred_address_wrong_family_v4()741     fn preferred_address_wrong_family_v4() {
742         mutate_spa(|v4, _, _| {
743             v4.as_mut().unwrap().set_ip(IpAddr::V6(Ipv6Addr::from(0)));
744         })
745         .encode(&mut Encoder::new(), PREFERRED_ADDRESS);
746     }
747 
748     #[test]
749     #[should_panic]
preferred_address_wrong_family_v6()750     fn preferred_address_wrong_family_v6() {
751         mutate_spa(|_, v6, _| {
752             v6.as_mut().unwrap().set_ip(IpAddr::V4(Ipv4Addr::from(0)));
753         })
754         .encode(&mut Encoder::new(), PREFERRED_ADDRESS);
755     }
756 
757     #[test]
758     #[should_panic]
preferred_address_neither()759     fn preferred_address_neither() {
760         mem::drop(PreferredAddress::new(None, None));
761     }
762 
763     #[test]
764     #[should_panic]
preferred_address_v4_unspecified()765     fn preferred_address_v4_unspecified() {
766         let _ = PreferredAddress::new(
767             Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::from(0)), 443)),
768             None,
769         );
770     }
771 
772     #[test]
773     #[should_panic]
preferred_address_v4_zero_port()774     fn preferred_address_v4_zero_port() {
775         let _ = PreferredAddress::new(
776             Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::from(0xc000_0201)), 0)),
777             None,
778         );
779     }
780 
781     #[test]
782     #[should_panic]
preferred_address_v6_unspecified()783     fn preferred_address_v6_unspecified() {
784         let _ = PreferredAddress::new(
785             None,
786             Some(SocketAddr::new(IpAddr::V6(Ipv6Addr::from(0)), 443)),
787         );
788     }
789 
790     #[test]
791     #[should_panic]
preferred_address_v6_zero_port()792     fn preferred_address_v6_zero_port() {
793         let _ = PreferredAddress::new(
794             None,
795             Some(SocketAddr::new(IpAddr::V6(Ipv6Addr::from(1)), 0)),
796         );
797     }
798 
799     #[test]
800     #[should_panic]
preferred_address_v4_is_v6()801     fn preferred_address_v4_is_v6() {
802         let _ = PreferredAddress::new(
803             Some(SocketAddr::new(IpAddr::V6(Ipv6Addr::from(1)), 443)),
804             None,
805         );
806     }
807 
808     #[test]
809     #[should_panic]
preferred_address_v6_is_v4()810     fn preferred_address_v6_is_v4() {
811         let _ = PreferredAddress::new(
812             None,
813             Some(SocketAddr::new(
814                 IpAddr::V4(Ipv4Addr::from(0xc000_0201)),
815                 443,
816             )),
817         );
818     }
819 
820     #[test]
compatible_0rtt_ignored_values()821     fn compatible_0rtt_ignored_values() {
822         let mut tps_a = TransportParameters::default();
823         tps_a.set(
824             STATELESS_RESET_TOKEN,
825             TransportParameter::Bytes(vec![1, 2, 3]),
826         );
827         tps_a.set(IDLE_TIMEOUT, TransportParameter::Integer(10));
828         tps_a.set(MAX_ACK_DELAY, TransportParameter::Integer(22));
829         tps_a.set(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(33));
830 
831         let mut tps_b = TransportParameters::default();
832         assert!(tps_a.ok_for_0rtt(&tps_b));
833         assert!(tps_b.ok_for_0rtt(&tps_a));
834 
835         tps_b.set(
836             STATELESS_RESET_TOKEN,
837             TransportParameter::Bytes(vec![8, 9, 10]),
838         );
839         tps_b.set(IDLE_TIMEOUT, TransportParameter::Integer(100));
840         tps_b.set(MAX_ACK_DELAY, TransportParameter::Integer(2));
841         tps_b.set(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(44));
842         assert!(tps_a.ok_for_0rtt(&tps_b));
843         assert!(tps_b.ok_for_0rtt(&tps_a));
844     }
845 
846     #[test]
compatible_0rtt_integers()847     fn compatible_0rtt_integers() {
848         let mut tps_a = TransportParameters::default();
849         const INTEGER_KEYS: &[TransportParameterId] = &[
850             INITIAL_MAX_DATA,
851             INITIAL_MAX_STREAM_DATA_BIDI_LOCAL,
852             INITIAL_MAX_STREAM_DATA_BIDI_REMOTE,
853             INITIAL_MAX_STREAM_DATA_UNI,
854             INITIAL_MAX_STREAMS_BIDI,
855             INITIAL_MAX_STREAMS_UNI,
856             MAX_UDP_PAYLOAD_SIZE,
857             MIN_ACK_DELAY,
858         ];
859         for i in INTEGER_KEYS {
860             tps_a.set(*i, TransportParameter::Integer(12));
861         }
862 
863         let tps_b = tps_a.clone();
864         assert!(tps_a.ok_for_0rtt(&tps_b));
865         assert!(tps_b.ok_for_0rtt(&tps_a));
866 
867         // For each integer key, choose a new value that will be accepted.
868         for i in INTEGER_KEYS {
869             let mut tps_b = tps_a.clone();
870             // Set a safe new value; reducing MIN_ACK_DELAY instead.
871             let safe_value = if *i == MIN_ACK_DELAY { 11 } else { 13 };
872             tps_b.set(*i, TransportParameter::Integer(safe_value));
873             // If the new value is not safe relative to the remembered value,
874             // then we can't attempt 0-RTT with these parameters.
875             assert!(!tps_a.ok_for_0rtt(&tps_b));
876             // The opposite situation is fine.
877             assert!(tps_b.ok_for_0rtt(&tps_a));
878         }
879 
880         // Drop integer values and check that that is OK.
881         for i in INTEGER_KEYS {
882             let mut tps_b = tps_a.clone();
883             tps_b.remove(*i);
884             // A value that is missing from what is rememebered is OK.
885             assert!(tps_a.ok_for_0rtt(&tps_b));
886             // A value that is rememebered, but not current is not OK.
887             assert!(!tps_b.ok_for_0rtt(&tps_a));
888         }
889     }
890 
891     /// `ACTIVE_CONNECTION_ID_LIMIT` can't be less than 2.
892     #[test]
active_connection_id_limit_min_2()893     fn active_connection_id_limit_min_2() {
894         let mut tps = TransportParameters::default();
895 
896         // Intentionally set an invalid value for the ACTIVE_CONNECTION_ID_LIMIT transport parameter.
897         tps.params
898             .insert(ACTIVE_CONNECTION_ID_LIMIT, TransportParameter::Integer(1));
899 
900         let mut enc = Encoder::default();
901         tps.encode(&mut enc);
902 
903         // When decoding a set of transport parameters with an invalid ACTIVE_CONNECTION_ID_LIMIT
904         // the result should be an error.
905         let invalid_decode_result = TransportParameters::decode(&mut enc.as_decoder());
906         assert!(invalid_decode_result.is_err());
907     }
908 }
909