1 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4 // option. This file may not be copied, modified, or distributed
5 // except according to those terms.
6 
7 use std::cell::RefCell;
8 use std::cmp::{max, min};
9 use std::convert::TryFrom;
10 use std::mem;
11 use std::ops::{Index, IndexMut, Range};
12 use std::rc::Rc;
13 use std::time::Instant;
14 
15 use neqo_common::{hex, hex_snip_middle, qdebug, qinfo, qtrace, Encoder, Role};
16 
17 use neqo_crypto::{
18     hkdf, hp::HpKey, Aead, Agent, AntiReplay, Cipher, Epoch, Error as CryptoError, HandshakeState,
19     PrivateKey, PublicKey, Record, RecordList, ResumptionToken, SymKey, ZeroRttChecker,
20     TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256, TLS_CT_HANDSHAKE,
21     TLS_EPOCH_APPLICATION_DATA, TLS_EPOCH_HANDSHAKE, TLS_EPOCH_INITIAL, TLS_EPOCH_ZERO_RTT,
22     TLS_VERSION_1_3,
23 };
24 
25 use crate::packet::{PacketBuilder, PacketNumber, QuicVersion};
26 use crate::recovery::RecoveryToken;
27 use crate::recv_stream::RxStreamOrderer;
28 use crate::send_stream::TxBuffer;
29 use crate::stats::FrameStats;
30 use crate::tparams::{TpZeroRttChecker, TransportParameters, TransportParametersHandler};
31 use crate::tracking::PacketNumberSpace;
32 use crate::{Error, Res};
33 
/// Upper bound on the size of an AEAD authentication tag.  `CryptoDxState::encrypt`
/// adds this to the plaintext length when sizing its output buffer.
const MAX_AUTH_TAG: usize = 32;
/// The number of invocations remaining on a write cipher before we try
/// to update keys.  This has to be much smaller than the number returned
/// by `CryptoDxState::limit` or updates will happen too often.  As we don't
/// need to ask permission to update, this can be quite small.
pub(crate) const UPDATE_WRITE_KEYS_AT: PacketNumber = 100;

// This is a testing kludge that allows for overwriting the number of
// invocations of the next cipher to operate.  With this, it is possible
// to test what happens when the number of invocations reaches 0, or
// when it hits `UPDATE_WRITE_KEYS_AT` and an automatic update should occur.
// This is a little crude, but it saves a lot of plumbing.
// The value is taken (reset to `None`) by the next `CryptoDxState::invoked`
// call that observes it, so each override applies only once.
#[cfg(test)]
thread_local!(pub(crate) static OVERWRITE_INVOCATIONS: RefCell<Option<PacketNumber>> = RefCell::default());
48 
/// QUIC's use of TLS for one connection: the TLS agent, the outgoing
/// handshake data for each packet number space, and the packet protection
/// keys derived from the TLS secrets.
#[derive(Debug)]
pub struct Crypto {
    /// The TLS client or server agent that drives the handshake.
    pub(crate) tls: Agent,
    /// Handshake records buffered per packet number space for sending in
    /// CRYPTO frames, with acknowledgment/loss tracking.
    pub(crate) streams: CryptoStreams,
    /// Packet protection keys for each encryption level.
    pub(crate) states: CryptoStates,
}
55 
/// Shared, interior-mutable handle to the transport parameters extension handler.
type TpHandler = Rc<RefCell<TransportParametersHandler>>;
57 
58 impl Crypto {
new( version: QuicVersion, mut agent: Agent, protocols: &[impl AsRef<str>], tphandler: TpHandler, ) -> Res<Self>59     pub fn new(
60         version: QuicVersion,
61         mut agent: Agent,
62         protocols: &[impl AsRef<str>],
63         tphandler: TpHandler,
64     ) -> Res<Self> {
65         agent.set_version_range(TLS_VERSION_1_3, TLS_VERSION_1_3)?;
66         agent.set_ciphers(&[
67             TLS_AES_128_GCM_SHA256,
68             TLS_AES_256_GCM_SHA384,
69             TLS_CHACHA20_POLY1305_SHA256,
70         ])?;
71         agent.set_alpn(protocols)?;
72         agent.disable_end_of_early_data()?;
73         // Always enable 0-RTT on the client, but the server needs
74         // more configuration passed to server_enable_0rtt.
75         if let Agent::Client(c) = &mut agent {
76             c.enable_0rtt()?;
77         }
78         let extension = match version {
79             QuicVersion::Version1 => 0x39,
80             QuicVersion::Draft29
81             | QuicVersion::Draft30
82             | QuicVersion::Draft31
83             | QuicVersion::Draft32 => 0xffa5,
84         };
85         agent.extension_handler(extension, tphandler)?;
86         Ok(Self {
87             tls: agent,
88             streams: Default::default(),
89             states: Default::default(),
90         })
91     }
92 
server_enable_0rtt( &mut self, tphandler: TpHandler, anti_replay: &AntiReplay, zero_rtt_checker: impl ZeroRttChecker + 'static, ) -> Res<()>93     pub fn server_enable_0rtt(
94         &mut self,
95         tphandler: TpHandler,
96         anti_replay: &AntiReplay,
97         zero_rtt_checker: impl ZeroRttChecker + 'static,
98     ) -> Res<()> {
99         if let Agent::Server(s) = &mut self.tls {
100             Ok(s.enable_0rtt(
101                 anti_replay,
102                 0xffff_ffff,
103                 TpZeroRttChecker::wrap(tphandler, zero_rtt_checker),
104             )?)
105         } else {
106             panic!("not a server");
107         }
108     }
109 
server_enable_ech( &mut self, config: u8, public_name: &str, sk: &PrivateKey, pk: &PublicKey, ) -> Res<()>110     pub fn server_enable_ech(
111         &mut self,
112         config: u8,
113         public_name: &str,
114         sk: &PrivateKey,
115         pk: &PublicKey,
116     ) -> Res<()> {
117         if let Agent::Server(s) = &mut self.tls {
118             s.enable_ech(config, public_name, sk, pk)?;
119             Ok(())
120         } else {
121             panic!("not a client");
122         }
123     }
124 
client_enable_ech(&mut self, ech_config_list: impl AsRef<[u8]>) -> Res<()>125     pub fn client_enable_ech(&mut self, ech_config_list: impl AsRef<[u8]>) -> Res<()> {
126         if let Agent::Client(c) = &mut self.tls {
127             c.enable_ech(ech_config_list)?;
128             Ok(())
129         } else {
130             panic!("not a client");
131         }
132     }
133 
134     /// Get the active ECH configuration, which is empty if ECH is disabled.
ech_config(&self) -> &[u8]135     pub fn ech_config(&self) -> &[u8] {
136         self.tls.ech_config()
137     }
138 
handshake( &mut self, now: Instant, space: PacketNumberSpace, data: Option<&[u8]>, ) -> Res<&HandshakeState>139     pub fn handshake(
140         &mut self,
141         now: Instant,
142         space: PacketNumberSpace,
143         data: Option<&[u8]>,
144     ) -> Res<&HandshakeState> {
145         let input = data.map(|d| {
146             qtrace!("Handshake record received {:0x?} ", d);
147             let epoch = match space {
148                 PacketNumberSpace::Initial => TLS_EPOCH_INITIAL,
149                 PacketNumberSpace::Handshake => TLS_EPOCH_HANDSHAKE,
150                 // Our epoch progresses forward, but the TLS epoch is fixed to 3.
151                 PacketNumberSpace::ApplicationData => TLS_EPOCH_APPLICATION_DATA,
152             };
153             Record {
154                 ct: TLS_CT_HANDSHAKE,
155                 epoch,
156                 data: d.to_vec(),
157             }
158         });
159 
160         match self.tls.handshake_raw(now, input) {
161             Ok(output) => {
162                 self.buffer_records(output)?;
163                 Ok(self.tls.state())
164             }
165             Err(CryptoError::EchRetry(v)) => Err(Error::EchRetry(v)),
166             Err(e) => {
167                 qinfo!("Handshake failed {:?}", e);
168                 Err(match self.tls.alert() {
169                     Some(a) => Error::CryptoAlert(*a),
170                     _ => Error::CryptoError(e),
171                 })
172             }
173         }
174     }
175 
176     /// Enable 0-RTT and return `true` if it is enabled successfully.
enable_0rtt(&mut self, role: Role) -> Res<bool>177     pub fn enable_0rtt(&mut self, role: Role) -> Res<bool> {
178         let info = self.tls.preinfo()?;
179         // `info.early_data()` returns false for a server,
180         // so use `early_data_cipher()` to tell if 0-RTT is enabled.
181         let cipher = info.early_data_cipher();
182         if cipher.is_none() {
183             return Ok(false);
184         }
185         let (dir, secret) = match role {
186             Role::Client => (
187                 CryptoDxDirection::Write,
188                 self.tls.write_secret(TLS_EPOCH_ZERO_RTT),
189             ),
190             Role::Server => (
191                 CryptoDxDirection::Read,
192                 self.tls.read_secret(TLS_EPOCH_ZERO_RTT),
193             ),
194         };
195         let secret = secret.ok_or(Error::InternalError(1))?;
196         self.states.set_0rtt_keys(dir, &secret, cipher.unwrap());
197         Ok(true)
198     }
199 
200     /// Returns true if new handshake keys were installed.
install_keys(&mut self, role: Role) -> Res<bool>201     pub fn install_keys(&mut self, role: Role) -> Res<bool> {
202         if !self.tls.state().is_final() {
203             let installed_hs = self.install_handshake_keys()?;
204             if role == Role::Server {
205                 self.maybe_install_application_write_key()?;
206             }
207             Ok(installed_hs)
208         } else {
209             Ok(false)
210         }
211     }
212 
install_handshake_keys(&mut self) -> Res<bool>213     fn install_handshake_keys(&mut self) -> Res<bool> {
214         qtrace!([self], "Attempt to install handshake keys");
215         let write_secret = if let Some(secret) = self.tls.write_secret(TLS_EPOCH_HANDSHAKE) {
216             secret
217         } else {
218             // No keys is fine.
219             return Ok(false);
220         };
221         let read_secret = self
222             .tls
223             .read_secret(TLS_EPOCH_HANDSHAKE)
224             .ok_or(Error::InternalError(2))?;
225         let cipher = match self.tls.info() {
226             None => self.tls.preinfo()?.cipher_suite(),
227             Some(info) => Some(info.cipher_suite()),
228         }
229         .ok_or(Error::InternalError(3))?;
230         self.states
231             .set_handshake_keys(&write_secret, &read_secret, cipher);
232         qdebug!([self], "Handshake keys installed");
233         Ok(true)
234     }
235 
maybe_install_application_write_key(&mut self) -> Res<()>236     fn maybe_install_application_write_key(&mut self) -> Res<()> {
237         qtrace!([self], "Attempt to install application write key");
238         if let Some(secret) = self.tls.write_secret(TLS_EPOCH_APPLICATION_DATA) {
239             self.states.set_application_write_key(secret)?;
240             qdebug!([self], "Application write key installed");
241         }
242         Ok(())
243     }
244 
install_application_keys(&mut self, expire_0rtt: Instant) -> Res<()>245     pub fn install_application_keys(&mut self, expire_0rtt: Instant) -> Res<()> {
246         self.maybe_install_application_write_key()?;
247         // The write key might have been installed earlier, but it should
248         // always be installed now.
249         debug_assert!(self.states.app_write.is_some());
250         let read_secret = self
251             .tls
252             .read_secret(TLS_EPOCH_APPLICATION_DATA)
253             .ok_or(Error::InternalError(4))?;
254         self.states
255             .set_application_read_key(read_secret, expire_0rtt)?;
256         qdebug!([self], "application read keys installed");
257         Ok(())
258     }
259 
260     /// Buffer crypto records for sending.
buffer_records(&mut self, records: RecordList) -> Res<()>261     pub fn buffer_records(&mut self, records: RecordList) -> Res<()> {
262         for r in records {
263             if r.ct != TLS_CT_HANDSHAKE {
264                 return Err(Error::ProtocolViolation);
265             }
266             qtrace!([self], "Adding CRYPTO data {:?}", r);
267             self.streams.send(PacketNumberSpace::from(r.epoch), &r.data);
268         }
269         Ok(())
270     }
271 
write_frame( &mut self, space: PacketNumberSpace, builder: &mut PacketBuilder, tokens: &mut Vec<RecoveryToken>, stats: &mut FrameStats, ) -> Res<()>272     pub fn write_frame(
273         &mut self,
274         space: PacketNumberSpace,
275         builder: &mut PacketBuilder,
276         tokens: &mut Vec<RecoveryToken>,
277         stats: &mut FrameStats,
278     ) -> Res<()> {
279         self.streams.write_frame(space, builder, tokens, stats)
280     }
281 
acked(&mut self, token: &CryptoRecoveryToken)282     pub fn acked(&mut self, token: &CryptoRecoveryToken) {
283         qinfo!(
284             "Acked crypto frame space={} offset={} length={}",
285             token.space,
286             token.offset,
287             token.length
288         );
289         self.streams.acked(token);
290     }
291 
lost(&mut self, token: &CryptoRecoveryToken)292     pub fn lost(&mut self, token: &CryptoRecoveryToken) {
293         qinfo!(
294             "Lost crypto frame space={} offset={} length={}",
295             token.space,
296             token.offset,
297             token.length
298         );
299         self.streams.lost(token);
300     }
301 
302     /// Mark any outstanding frames in the indicated space as "lost" so
303     /// that they can be sent again.
resend_unacked(&mut self, space: PacketNumberSpace)304     pub fn resend_unacked(&mut self, space: PacketNumberSpace) {
305         self.streams.resend_unacked(space);
306     }
307 
308     /// Discard state for a packet number space and return true
309     /// if something was discarded.
discard(&mut self, space: PacketNumberSpace) -> bool310     pub fn discard(&mut self, space: PacketNumberSpace) -> bool {
311         self.streams.discard(space);
312         self.states.discard(space)
313     }
314 
create_resumption_token( &mut self, new_token: Option<&[u8]>, tps: &TransportParameters, rtt: u64, ) -> Option<ResumptionToken>315     pub fn create_resumption_token(
316         &mut self,
317         new_token: Option<&[u8]>,
318         tps: &TransportParameters,
319         rtt: u64,
320     ) -> Option<ResumptionToken> {
321         if let Agent::Client(ref mut c) = self.tls {
322             if let Some(ref t) = c.resumption_token() {
323                 qtrace!("TLS token {}", hex(t.as_ref()));
324                 let mut enc = Encoder::default();
325                 enc.encode_varint(rtt);
326                 enc.encode_vvec_with(|enc_inner| {
327                     tps.encode(enc_inner);
328                 });
329                 enc.encode_vvec(new_token.unwrap_or(&[]));
330                 enc.encode(t.as_ref());
331                 qinfo!("resumption token {}", hex_snip_middle(&enc[..]));
332                 Some(ResumptionToken::new(enc.into(), t.expiration_time()))
333             } else {
334                 None
335             }
336         } else {
337             unreachable!("It is a server.");
338         }
339     }
340 
has_resumption_token(&self) -> bool341     pub fn has_resumption_token(&self) -> bool {
342         if let Agent::Client(c) = &self.tls {
343             c.has_resumption_token()
344         } else {
345             unreachable!("It is a server.");
346         }
347     }
348 }
349 
350 impl ::std::fmt::Display for Crypto {
fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result351     fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
352         write!(f, "Crypto")
353     }
354 }
355 
/// Which way a set of packet protection keys is used.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CryptoDxDirection {
    /// Keys for removing protection from received packets (`decrypt`).
    Read,
    /// Keys for protecting packets that we send (`encrypt`).
    Write,
}
361 
/// Packet protection state for one direction at one epoch.
///
/// This holds the AEAD and header protection keys, together with the
/// bookkeeping needed for key updates and AEAD usage limits.
#[derive(Debug)]
pub struct CryptoDxState {
    /// Whether these keys protect outgoing packets or unprotect incoming ones.
    direction: CryptoDxDirection,
    /// The epoch of this crypto state.  This initially tracks TLS epochs
    /// via DTLS: 0 = initial, 1 = 0-RTT, 2 = handshake, 3 = application.
    /// But we don't need to keep that, and QUIC isn't limited in how
    /// many times keys can be updated, so we don't use `u16` for this.
    epoch: usize,
    /// Packet payload protection.
    aead: Aead,
    /// Header protection; carried over unchanged across key updates.
    hpkey: HpKey,
    /// This tracks the range of packet numbers that have been seen.  This allows
    /// for verifying that packet numbers before a key update are strictly lower
    /// than packet numbers after a key update.
    used_pn: Range<PacketNumber>,
    /// This is the minimum packet number that is allowed.
    min_pn: PacketNumber,
    /// The total number of operations that are remaining before the keys
    /// become exhausted and can't be used any more.
    invocations: PacketNumber,
}
382 
383 impl CryptoDxState {
384     #[allow(unknown_lints, renamed_and_removed_lints, clippy::unknown_clippy_lints)] // Until we require rust 1.45.
385     #[allow(clippy::reversed_empty_ranges)] // To initialize an empty range.
new( direction: CryptoDxDirection, epoch: Epoch, secret: &SymKey, cipher: Cipher, ) -> Self386     pub fn new(
387         direction: CryptoDxDirection,
388         epoch: Epoch,
389         secret: &SymKey,
390         cipher: Cipher,
391     ) -> Self {
392         qinfo!(
393             "Making {:?} {} CryptoDxState, cipher={}",
394             direction,
395             epoch,
396             cipher
397         );
398         Self {
399             direction,
400             epoch: usize::from(epoch),
401             aead: Aead::new(TLS_VERSION_1_3, cipher, secret, "quic ").unwrap(),
402             hpkey: HpKey::extract(TLS_VERSION_1_3, cipher, secret, "quic hp").unwrap(),
403             used_pn: 0..0,
404             min_pn: 0,
405             invocations: Self::limit(direction, cipher),
406         }
407     }
408 
new_initial( quic_version: QuicVersion, direction: CryptoDxDirection, label: &str, dcid: &[u8], ) -> Self409     pub fn new_initial(
410         quic_version: QuicVersion,
411         direction: CryptoDxDirection,
412         label: &str,
413         dcid: &[u8],
414     ) -> Self {
415         const INITIAL_SALT_V1: &[u8] = &[
416             0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8,
417             0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a,
418         ];
419         const INITIAL_SALT_29_32: &[u8] = &[
420             0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61,
421             0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99,
422         ];
423         qtrace!("new_initial for {:?}", quic_version);
424         let salt = match quic_version {
425             QuicVersion::Version1 => INITIAL_SALT_V1,
426             QuicVersion::Draft29
427             | QuicVersion::Draft30
428             | QuicVersion::Draft31
429             | QuicVersion::Draft32 => INITIAL_SALT_29_32,
430         };
431         let cipher = TLS_AES_128_GCM_SHA256;
432         let initial_secret = hkdf::extract(
433             TLS_VERSION_1_3,
434             cipher,
435             Some(hkdf::import_key(TLS_VERSION_1_3, salt).as_ref().unwrap()),
436             hkdf::import_key(TLS_VERSION_1_3, dcid).as_ref().unwrap(),
437         )
438         .unwrap();
439 
440         let secret =
441             hkdf::expand_label(TLS_VERSION_1_3, cipher, &initial_secret, &[], label).unwrap();
442 
443         Self::new(direction, TLS_EPOCH_INITIAL, &secret, cipher)
444     }
445 
446     /// Determine the confidentiality and integrity limits for the cipher.
limit(direction: CryptoDxDirection, cipher: Cipher) -> PacketNumber447     fn limit(direction: CryptoDxDirection, cipher: Cipher) -> PacketNumber {
448         match direction {
449             // This uses the smaller limits for 2^16 byte packets
450             // as we don't control incoming packet size.
451             CryptoDxDirection::Read => match cipher {
452                 TLS_AES_128_GCM_SHA256 => 1 << 52,
453                 TLS_AES_256_GCM_SHA384 => PacketNumber::MAX,
454                 TLS_CHACHA20_POLY1305_SHA256 => 1 << 36,
455                 _ => unreachable!(),
456             },
457             // This uses the larger limits for 2^11 byte packets.
458             CryptoDxDirection::Write => match cipher {
459                 TLS_AES_128_GCM_SHA256 | TLS_AES_256_GCM_SHA384 => 1 << 28,
460                 TLS_CHACHA20_POLY1305_SHA256 => PacketNumber::MAX,
461                 _ => unreachable!(),
462             },
463         }
464     }
465 
invoked(&mut self) -> Res<()>466     fn invoked(&mut self) -> Res<()> {
467         #[cfg(test)]
468         OVERWRITE_INVOCATIONS.with(|v| {
469             if let Some(i) = v.borrow_mut().take() {
470                 neqo_common::qwarn!("Setting {:?} invocations to {}", self.direction, i);
471                 self.invocations = i;
472             }
473         });
474         self.invocations = self
475             .invocations
476             .checked_sub(1)
477             .ok_or(Error::KeysExhausted)?;
478         Ok(())
479     }
480 
481     /// Determine whether we should initiate a key update.
should_update(&self) -> bool482     pub fn should_update(&self) -> bool {
483         // There is no point in updating read keys as the limit is global.
484         debug_assert_eq!(self.direction, CryptoDxDirection::Write);
485         self.invocations <= UPDATE_WRITE_KEYS_AT
486     }
487 
next(&self, next_secret: &SymKey, cipher: Cipher) -> Self488     pub fn next(&self, next_secret: &SymKey, cipher: Cipher) -> Self {
489         let pn = self.next_pn();
490         // We count invocations of each write key just for that key, but all
491         // attempts to invocations to read count toward a single limit.
492         // This doesn't count use of Handshake keys.
493         let invocations = if self.direction == CryptoDxDirection::Read {
494             self.invocations
495         } else {
496             Self::limit(CryptoDxDirection::Write, cipher)
497         };
498         Self {
499             direction: self.direction,
500             epoch: self.epoch + 1,
501             aead: Aead::new(TLS_VERSION_1_3, cipher, next_secret, "quic ").unwrap(),
502             hpkey: self.hpkey.clone(),
503             used_pn: pn..pn,
504             min_pn: pn,
505             invocations,
506         }
507     }
508 
509     #[must_use]
key_phase(&self) -> bool510     pub fn key_phase(&self) -> bool {
511         // Epoch 3 => 0, 4 => 1, 5 => 0, 6 => 1, ...
512         self.epoch & 1 != 1
513     }
514 
515     /// This is a continuation of a previous, so adjust the range accordingly.
516     /// Fail if the two ranges overlap.  Do nothing if the directions don't match.
continuation(&mut self, prev: &Self) -> Res<()>517     pub fn continuation(&mut self, prev: &Self) -> Res<()> {
518         debug_assert_eq!(self.direction, prev.direction);
519         let next = prev.next_pn();
520         self.min_pn = next;
521         // TODO(mt) use Range::is_empty() when available
522         if self.used_pn.start == self.used_pn.end {
523             self.used_pn = next..next;
524             Ok(())
525         } else if prev.used_pn.end > self.used_pn.start {
526             qdebug!(
527                 [self],
528                 "Found packet with too new packet number {} > {}, compared to {}",
529                 self.used_pn.start,
530                 prev.used_pn.end,
531                 prev,
532             );
533             Err(Error::PacketNumberOverlap)
534         } else {
535             self.used_pn.start = next;
536             Ok(())
537         }
538     }
539 
540     /// Mark a packet number as used.  If this is too low, reject it.
541     /// Note that this won't catch a value that is too high if packets protected with
542     /// old keys are received after a key update.  That needs to be caught elsewhere.
used(&mut self, pn: PacketNumber) -> Res<()>543     pub fn used(&mut self, pn: PacketNumber) -> Res<()> {
544         if pn < self.min_pn {
545             qdebug!(
546                 [self],
547                 "Found packet with too old packet number: {} < {}",
548                 pn,
549                 self.min_pn
550             );
551             return Err(Error::PacketNumberOverlap);
552         }
553         if self.used_pn.start == self.used_pn.end {
554             self.used_pn.start = pn;
555         }
556         self.used_pn.end = max(pn + 1, self.used_pn.end);
557         Ok(())
558     }
559 
560     #[must_use]
needs_update(&self) -> bool561     pub fn needs_update(&self) -> bool {
562         // Only initiate a key update if we have processed exactly one packet
563         // and we are in an epoch greater than 3.
564         self.used_pn.start + 1 == self.used_pn.end
565             && self.epoch > usize::from(TLS_EPOCH_APPLICATION_DATA)
566     }
567 
568     #[must_use]
can_update(&self, largest_acknowledged: Option<PacketNumber>) -> bool569     pub fn can_update(&self, largest_acknowledged: Option<PacketNumber>) -> bool {
570         if let Some(la) = largest_acknowledged {
571             self.used_pn.contains(&la)
572         } else {
573             // If we haven't received any acknowledgments, it's OK to update
574             // the first application data epoch.
575             self.epoch == usize::from(TLS_EPOCH_APPLICATION_DATA)
576         }
577     }
578 
compute_mask(&self, sample: &[u8]) -> Res<Vec<u8>>579     pub fn compute_mask(&self, sample: &[u8]) -> Res<Vec<u8>> {
580         let mask = self.hpkey.mask(sample)?;
581         qtrace!([self], "HP sample={} mask={}", hex(sample), hex(&mask));
582         Ok(mask)
583     }
584 
585     #[must_use]
next_pn(&self) -> PacketNumber586     pub fn next_pn(&self) -> PacketNumber {
587         self.used_pn.end
588     }
589 
encrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res<Vec<u8>>590     pub fn encrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res<Vec<u8>> {
591         debug_assert_eq!(self.direction, CryptoDxDirection::Write);
592         qtrace!(
593             [self],
594             "encrypt pn={} hdr={} body={}",
595             pn,
596             hex(hdr),
597             hex(body)
598         );
599         // The numbers in `Self::limit` assume a maximum packet size of 2^11.
600         if body.len() > 2048 {
601             debug_assert!(false);
602             return Err(Error::InternalError(12));
603         }
604         self.invoked()?;
605 
606         let size = body.len() + MAX_AUTH_TAG;
607         let mut out = vec![0; size];
608         let res = self.aead.encrypt(pn, hdr, body, &mut out)?;
609 
610         qtrace!([self], "encrypt ct={}", hex(res));
611         debug_assert_eq!(pn, self.next_pn());
612         self.used(pn)?;
613         Ok(res.to_vec())
614     }
615 
616     #[must_use]
expansion(&self) -> usize617     pub fn expansion(&self) -> usize {
618         self.aead.expansion()
619     }
620 
decrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res<Vec<u8>>621     pub fn decrypt(&mut self, pn: PacketNumber, hdr: &[u8], body: &[u8]) -> Res<Vec<u8>> {
622         debug_assert_eq!(self.direction, CryptoDxDirection::Read);
623         qtrace!(
624             [self],
625             "decrypt pn={} hdr={} body={}",
626             pn,
627             hex(hdr),
628             hex(body)
629         );
630         self.invoked()?;
631         let mut out = vec![0; body.len()];
632         let res = self.aead.decrypt(pn, hdr, body, &mut out)?;
633         self.used(pn)?;
634         Ok(res.to_vec())
635     }
636 
637     #[cfg(all(test, not(feature = "fuzzing")))]
test_default() -> Self638     pub(crate) fn test_default() -> Self {
639         // This matches the value in packet.rs
640         const CLIENT_CID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08];
641         Self::new_initial(
642             QuicVersion::default(),
643             CryptoDxDirection::Write,
644             "server in",
645             CLIENT_CID,
646         )
647     }
648 
649     /// Get the amount of extra padding packets protected with this profile need.
650     /// This is the difference between the size of the header protection sample
651     /// and the AEAD expansion.
extra_padding(&self) -> usize652     pub fn extra_padding(&self) -> usize {
653         self.hpkey
654             .sample_size()
655             .saturating_sub(self.aead.expansion())
656     }
657 }
658 
659 impl std::fmt::Display for CryptoDxState {
fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result660     fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
661         write!(f, "epoch {} {:?}", self.epoch, self.direction)
662     }
663 }
664 
/// Keys for both directions of a single encryption level.
/// Index with `CryptoDxDirection` to select one of them.
#[derive(Debug)]
pub struct CryptoState {
    /// Keys for protecting outgoing packets (`CryptoDxDirection::Write`).
    tx: CryptoDxState,
    /// Keys for unprotecting incoming packets (`CryptoDxDirection::Read`).
    rx: CryptoDxState,
}
670 
671 impl Index<CryptoDxDirection> for CryptoState {
672     type Output = CryptoDxState;
673 
index(&self, dir: CryptoDxDirection) -> &Self::Output674     fn index(&self, dir: CryptoDxDirection) -> &Self::Output {
675         match dir {
676             CryptoDxDirection::Read => &self.rx,
677             CryptoDxDirection::Write => &self.tx,
678         }
679     }
680 }
681 
682 impl IndexMut<CryptoDxDirection> for CryptoState {
index_mut(&mut self, dir: CryptoDxDirection) -> &mut Self::Output683     fn index_mut(&mut self, dir: CryptoDxDirection) -> &mut Self::Output {
684         match dir {
685             CryptoDxDirection::Read => &mut self.rx,
686             CryptoDxDirection::Write => &mut self.tx,
687         }
688     }
689 }
690 
/// `CryptoDxAppData` wraps the state necessary for one direction of application data keys.
/// This includes the secret needed to generate the next set of keys.
#[derive(Debug)]
pub(crate) struct CryptoDxAppData {
    /// The keys currently in use for this direction.
    dx: CryptoDxState,
    /// The cipher suite, reused for every key update.
    cipher: Cipher,
    // Not the secret used to create `self.dx`, but the one needed for the next iteration.
    next_secret: SymKey,
}
700 
701 impl CryptoDxAppData {
new(dir: CryptoDxDirection, secret: SymKey, cipher: Cipher) -> Res<Self>702     pub fn new(dir: CryptoDxDirection, secret: SymKey, cipher: Cipher) -> Res<Self> {
703         Ok(Self {
704             dx: CryptoDxState::new(dir, TLS_EPOCH_APPLICATION_DATA, &secret, cipher),
705             cipher,
706             next_secret: Self::update_secret(cipher, &secret)?,
707         })
708     }
709 
update_secret(cipher: Cipher, secret: &SymKey) -> Res<SymKey>710     fn update_secret(cipher: Cipher, secret: &SymKey) -> Res<SymKey> {
711         let next = hkdf::expand_label(TLS_VERSION_1_3, cipher, secret, &[], "quic ku")?;
712         Ok(next)
713     }
714 
next(&self) -> Res<Self>715     pub fn next(&self) -> Res<Self> {
716         if self.dx.epoch == usize::max_value() {
717             // Guard against too many key updates.
718             return Err(Error::KeysExhausted);
719         }
720         let next_secret = Self::update_secret(self.cipher, &self.next_secret)?;
721         Ok(Self {
722             dx: self.dx.next(&self.next_secret, self.cipher),
723             cipher: self.cipher,
724             next_secret,
725         })
726     }
727 
epoch(&self) -> usize728     pub fn epoch(&self) -> usize {
729         self.dx.epoch
730     }
731 }
732 
/// Encryption levels used to select keys.
///
/// Unlike `PacketNumberSpace`, 0-RTT is distinct from application data here.
/// Variants are ordered as declared (via the derived `Ord`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum CryptoSpace {
    Initial,
    ZeroRtt,
    Handshake,
    ApplicationData,
}
740 
/// All packet protection keys for a connection, by encryption level.
#[derive(Debug, Default)]
pub struct CryptoStates {
    /// Initial keys, both directions.
    initial: Option<CryptoState>,
    /// Handshake keys, both directions.
    handshake: Option<CryptoState>,
    /// 0-RTT keys: write on a client, read on a server.
    zero_rtt: Option<CryptoDxState>, // One direction only!
    /// NOTE(review): presumably the negotiated cipher suite, used when
    /// installing later keys; confirm where it is set.
    cipher: Cipher,
    /// Current application data write keys.
    app_write: Option<CryptoDxAppData>,
    /// Current application data read keys.
    app_read: Option<CryptoDxAppData>,
    /// NOTE(review): presumably read keys for the next key phase, held ready
    /// for a peer-initiated update; confirm against `rx`/key update code.
    app_read_next: Option<CryptoDxAppData>,
    // If this is set, then we have noticed a genuine update.
    // Once this time passes, we should switch in new keys.
    read_update_time: Option<Instant>,
}
754 
755 impl CryptoStates {
756     /// Select a `CryptoDxState` and `CryptoSpace` for the given `PacketNumberSpace`.
757     /// This selects 0-RTT keys for `PacketNumberSpace::ApplicationData` if 1-RTT keys are
758     /// not yet available.
select_tx( &mut self, space: PacketNumberSpace, ) -> Option<(CryptoSpace, &mut CryptoDxState)>759     pub fn select_tx(
760         &mut self,
761         space: PacketNumberSpace,
762     ) -> Option<(CryptoSpace, &mut CryptoDxState)> {
763         match space {
764             PacketNumberSpace::Initial => self
765                 .tx(CryptoSpace::Initial)
766                 .map(|dx| (CryptoSpace::Initial, dx)),
767             PacketNumberSpace::Handshake => self
768                 .tx(CryptoSpace::Handshake)
769                 .map(|dx| (CryptoSpace::Handshake, dx)),
770             PacketNumberSpace::ApplicationData => {
771                 if let Some(app) = self.app_write.as_mut() {
772                     Some((CryptoSpace::ApplicationData, &mut app.dx))
773                 } else {
774                     self.zero_rtt.as_mut().map(|dx| (CryptoSpace::ZeroRtt, dx))
775                 }
776             }
777         }
778     }
779 
tx<'a>(&'a mut self, cspace: CryptoSpace) -> Option<&'a mut CryptoDxState>780     pub fn tx<'a>(&'a mut self, cspace: CryptoSpace) -> Option<&'a mut CryptoDxState> {
781         let tx = |k: Option<&'a mut CryptoState>| k.map(|dx| &mut dx.tx);
782         match cspace {
783             CryptoSpace::Initial => tx(self.initial.as_mut()),
784             CryptoSpace::ZeroRtt => self
785                 .zero_rtt
786                 .as_mut()
787                 .filter(|z| z.direction == CryptoDxDirection::Write),
788             CryptoSpace::Handshake => tx(self.handshake.as_mut()),
789             CryptoSpace::ApplicationData => self.app_write.as_mut().map(|app| &mut app.dx),
790         }
791     }
792 
rx_hp(&mut self, cspace: CryptoSpace) -> Option<&mut CryptoDxState>793     pub fn rx_hp(&mut self, cspace: CryptoSpace) -> Option<&mut CryptoDxState> {
794         if let CryptoSpace::ApplicationData = cspace {
795             self.app_read.as_mut().map(|ar| &mut ar.dx)
796         } else {
797             self.rx(cspace, false)
798         }
799     }
800 
rx<'a>( &'a mut self, cspace: CryptoSpace, key_phase: bool, ) -> Option<&'a mut CryptoDxState>801     pub fn rx<'a>(
802         &'a mut self,
803         cspace: CryptoSpace,
804         key_phase: bool,
805     ) -> Option<&'a mut CryptoDxState> {
806         let rx = |x: Option<&'a mut CryptoState>| x.map(|dx| &mut dx.rx);
807         match cspace {
808             CryptoSpace::Initial => rx(self.initial.as_mut()),
809             CryptoSpace::ZeroRtt => self
810                 .zero_rtt
811                 .as_mut()
812                 .filter(|z| z.direction == CryptoDxDirection::Read),
813             CryptoSpace::Handshake => rx(self.handshake.as_mut()),
814             CryptoSpace::ApplicationData => {
815                 let f = |a: Option<&'a mut CryptoDxAppData>| {
816                     a.filter(|ar| ar.dx.key_phase() == key_phase)
817                 };
818                 // XOR to reduce the leakage about which key is chosen.
819                 f(self.app_read.as_mut())
820                     .xor(f(self.app_read_next.as_mut()))
821                     .map(|ar| &mut ar.dx)
822             }
823         }
824     }
825 
826     /// Whether keys for processing packets in the indicated space are pending.
827     /// This allows the caller to determine whether to save a packet for later
828     /// when keys are not available.
829     /// NOTE: 0-RTT keys are not considered here.  The expectation is that a
830     /// server will have to save 0-RTT packets in a different place.  Though it
831     /// is possible to attribute 0-RTT packets to an existing connection if there
832     /// is a multi-packet Initial, that is an unusual circumstance, so we
833     /// don't do caching for that in those places that call this function.
rx_pending(&self, space: CryptoSpace) -> bool834     pub fn rx_pending(&self, space: CryptoSpace) -> bool {
835         match space {
836             CryptoSpace::Initial | CryptoSpace::ZeroRtt => false,
837             CryptoSpace::Handshake => self.handshake.is_none() && self.initial.is_some(),
838             CryptoSpace::ApplicationData => self.app_read.is_none(),
839         }
840     }
841 
842     /// Create the initial crypto state.
init(&mut self, quic_version: QuicVersion, role: Role, dcid: &[u8])843     pub fn init(&mut self, quic_version: QuicVersion, role: Role, dcid: &[u8]) {
844         const CLIENT_INITIAL_LABEL: &str = "client in";
845         const SERVER_INITIAL_LABEL: &str = "server in";
846 
847         qinfo!(
848             [self],
849             "Creating initial cipher state role={:?} dcid={}",
850             role,
851             hex(dcid)
852         );
853 
854         let (write, read) = match role {
855             Role::Client => (CLIENT_INITIAL_LABEL, SERVER_INITIAL_LABEL),
856             Role::Server => (SERVER_INITIAL_LABEL, CLIENT_INITIAL_LABEL),
857         };
858 
859         let mut initial = CryptoState {
860             tx: CryptoDxState::new_initial(quic_version, CryptoDxDirection::Write, write, dcid),
861             rx: CryptoDxState::new_initial(quic_version, CryptoDxDirection::Read, read, dcid),
862         };
863         if let Some(prev) = &self.initial {
864             qinfo!(
865                 [self],
866                 "Continue packet numbers for initial after retry (write is {:?})",
867                 prev.rx.used_pn,
868             );
869             initial.tx.continuation(&prev.tx).unwrap();
870         }
871         self.initial = Some(initial);
872     }
873 
set_0rtt_keys(&mut self, dir: CryptoDxDirection, secret: &SymKey, cipher: Cipher)874     pub fn set_0rtt_keys(&mut self, dir: CryptoDxDirection, secret: &SymKey, cipher: Cipher) {
875         qtrace!([self], "install 0-RTT keys");
876         self.zero_rtt = Some(CryptoDxState::new(dir, TLS_EPOCH_ZERO_RTT, secret, cipher));
877     }
878 
879     /// Discard keys and return true if that happened.
discard(&mut self, space: PacketNumberSpace) -> bool880     pub fn discard(&mut self, space: PacketNumberSpace) -> bool {
881         match space {
882             PacketNumberSpace::Initial => self.initial.take().is_some(),
883             PacketNumberSpace::Handshake => self.handshake.take().is_some(),
884             PacketNumberSpace::ApplicationData => panic!("Can't drop application data keys"),
885         }
886     }
887 
discard_0rtt_keys(&mut self)888     pub fn discard_0rtt_keys(&mut self) {
889         qtrace!([self], "discard 0-RTT keys");
890         assert!(
891             self.app_read.is_none(),
892             "Can't discard 0-RTT after setting application keys"
893         );
894         self.zero_rtt = None;
895     }
896 
set_handshake_keys( &mut self, write_secret: &SymKey, read_secret: &SymKey, cipher: Cipher, )897     pub fn set_handshake_keys(
898         &mut self,
899         write_secret: &SymKey,
900         read_secret: &SymKey,
901         cipher: Cipher,
902     ) {
903         self.cipher = cipher;
904         self.handshake = Some(CryptoState {
905             tx: CryptoDxState::new(
906                 CryptoDxDirection::Write,
907                 TLS_EPOCH_HANDSHAKE,
908                 write_secret,
909                 cipher,
910             ),
911             rx: CryptoDxState::new(
912                 CryptoDxDirection::Read,
913                 TLS_EPOCH_HANDSHAKE,
914                 read_secret,
915                 cipher,
916             ),
917         });
918     }
919 
set_application_write_key(&mut self, secret: SymKey) -> Res<()>920     pub fn set_application_write_key(&mut self, secret: SymKey) -> Res<()> {
921         debug_assert!(self.app_write.is_none());
922         debug_assert_ne!(self.cipher, 0);
923         let mut app = CryptoDxAppData::new(CryptoDxDirection::Write, secret, self.cipher)?;
924         if let Some(z) = &self.zero_rtt {
925             if z.direction == CryptoDxDirection::Write {
926                 app.dx.continuation(z)?;
927             }
928         }
929         self.zero_rtt = None;
930         self.app_write = Some(app);
931         Ok(())
932     }
933 
set_application_read_key(&mut self, secret: SymKey, expire_0rtt: Instant) -> Res<()>934     pub fn set_application_read_key(&mut self, secret: SymKey, expire_0rtt: Instant) -> Res<()> {
935         debug_assert!(self.app_write.is_some(), "should have write keys installed");
936         debug_assert!(self.app_read.is_none());
937         let mut app = CryptoDxAppData::new(CryptoDxDirection::Read, secret, self.cipher)?;
938         if let Some(z) = &self.zero_rtt {
939             if z.direction == CryptoDxDirection::Read {
940                 app.dx.continuation(z)?;
941             }
942             self.read_update_time = Some(expire_0rtt);
943         }
944         self.app_read_next = Some(app.next()?);
945         self.app_read = Some(app);
946         Ok(())
947     }
948 
949     /// Update the write keys.
initiate_key_update(&mut self, largest_acknowledged: Option<PacketNumber>) -> Res<()>950     pub fn initiate_key_update(&mut self, largest_acknowledged: Option<PacketNumber>) -> Res<()> {
951         // Only update if we are able to. We can only do this if we have
952         // received an acknowledgement for a packet in the current phase.
953         // Also, skip this if we are waiting for read keys on the existing
954         // key update to be rolled over.
955         let write = &self.app_write.as_ref().unwrap().dx;
956         if write.can_update(largest_acknowledged) && self.read_update_time.is_none() {
957             // This call additionally checks that we don't advance to the next
958             // epoch while a key update is in progress.
959             if self.maybe_update_write()? {
960                 Ok(())
961             } else {
962                 qdebug!([self], "Write keys already updated");
963                 Err(Error::KeyUpdateBlocked)
964             }
965         } else {
966             qdebug!([self], "Waiting for ACK or blocked on read key timer");
967             Err(Error::KeyUpdateBlocked)
968         }
969     }
970 
971     /// Try to update, and return true if it happened.
maybe_update_write(&mut self) -> Res<bool>972     fn maybe_update_write(&mut self) -> Res<bool> {
973         // Update write keys.  But only do so if the write keys are not already
974         // ahead of the read keys.  If we initiated the key update, the write keys
975         // will already be ahead.
976         debug_assert!(self.read_update_time.is_none());
977         let write = &self.app_write.as_ref().unwrap();
978         let read = &self.app_read.as_ref().unwrap();
979         if write.epoch() == read.epoch() {
980             qdebug!([self], "Update write keys to epoch={}", write.epoch() + 1);
981             self.app_write = Some(write.next()?);
982             Ok(true)
983         } else {
984             Ok(false)
985         }
986     }
987 
988     /// Check whether write keys are close to running out of invocations.
989     /// If that is close, update them if possible.  Failing to update at
990     /// this stage is cause for a fatal error.
auto_update(&mut self) -> Res<()>991     pub fn auto_update(&mut self) -> Res<()> {
992         if let Some(app_write) = self.app_write.as_ref() {
993             if app_write.dx.should_update() {
994                 qinfo!([self], "Initiating automatic key update");
995                 if !self.maybe_update_write()? {
996                     return Err(Error::KeysExhausted);
997                 }
998             }
999         }
1000         Ok(())
1001     }
1002 
has_0rtt_read(&self) -> bool1003     fn has_0rtt_read(&self) -> bool {
1004         self.zero_rtt
1005             .as_ref()
1006             .filter(|z| z.direction == CryptoDxDirection::Read)
1007             .is_some()
1008     }
1009 
1010     /// Prepare to update read keys.  This doesn't happen immediately as
1011     /// we want to ensure that we can continue to receive any delayed
1012     /// packets that use the old keys.  So we just set a timer.
key_update_received(&mut self, expiration: Instant) -> Res<()>1013     pub fn key_update_received(&mut self, expiration: Instant) -> Res<()> {
1014         qtrace!([self], "Key update received");
1015         // If we received a key update, then we assume that the peer has
1016         // acknowledged a packet we sent in this epoch. It's OK to do that
1017         // because they aren't allowed to update without first having received
1018         // something from us. If the ACK isn't in the packet that triggered this
1019         // key update, it must be in some other packet they have sent.
1020         let _ = self.maybe_update_write()?;
1021 
1022         // We shouldn't have 0-RTT keys at this point, but if we do, dump them.
1023         debug_assert_eq!(self.read_update_time.is_some(), self.has_0rtt_read());
1024         if self.has_0rtt_read() {
1025             self.zero_rtt = None;
1026         }
1027         self.read_update_time = Some(expiration);
1028         Ok(())
1029     }
1030 
1031     #[must_use]
update_time(&self) -> Option<Instant>1032     pub fn update_time(&self) -> Option<Instant> {
1033         self.read_update_time
1034     }
1035 
1036     /// Check if time has passed for updating key update parameters.
1037     /// If it has, then swap keys over and allow more key updates to be initiated.
1038     /// This is also used to discard 0-RTT read keys at the server in the same way.
check_key_update(&mut self, now: Instant) -> Res<()>1039     pub fn check_key_update(&mut self, now: Instant) -> Res<()> {
1040         if let Some(expiry) = self.read_update_time {
1041             // If enough time has passed, then install new keys and clear the timer.
1042             if now >= expiry {
1043                 if self.has_0rtt_read() {
1044                     qtrace!([self], "Discarding 0-RTT keys");
1045                     self.zero_rtt = None;
1046                 } else {
1047                     qtrace!([self], "Rotating read keys");
1048                     mem::swap(&mut self.app_read, &mut self.app_read_next);
1049                     self.app_read_next = Some(self.app_read.as_ref().unwrap().next()?);
1050                 }
1051                 self.read_update_time = None;
1052             }
1053         }
1054         Ok(())
1055     }
1056 
1057     /// Get the current/highest epoch.  This returns (write, read) epochs.
1058     #[cfg(test)]
get_epochs(&self) -> (Option<usize>, Option<usize>)1059     pub fn get_epochs(&self) -> (Option<usize>, Option<usize>) {
1060         let to_epoch = |app: &Option<CryptoDxAppData>| app.as_ref().map(|a| a.dx.epoch);
1061         (to_epoch(&self.app_write), to_epoch(&self.app_read))
1062     }
1063 
1064     /// While we are awaiting the completion of a key update, we might receive
1065     /// valid packets that are protected with old keys. We need to ensure that
1066     /// these don't carry packet numbers higher than those in packets protected
1067     /// with the newer keys.  To ensure that, this is called after every decryption.
check_pn_overlap(&mut self) -> Res<()>1068     pub fn check_pn_overlap(&mut self) -> Res<()> {
1069         // We only need to do the check while we are waiting for read keys to be updated.
1070         if self.read_update_time.is_some() {
1071             qtrace!([self], "Checking for PN overlap");
1072             let next_dx = &mut self.app_read_next.as_mut().unwrap().dx;
1073             next_dx.continuation(&self.app_read.as_ref().unwrap().dx)?;
1074         }
1075         Ok(())
1076     }
1077 
1078     /// Make some state for removing protection in tests.
1079     #[cfg(not(feature = "fuzzing"))]
1080     #[cfg(test)]
test_default() -> Self1081     pub(crate) fn test_default() -> Self {
1082         let read = |epoch| {
1083             let mut dx = CryptoDxState::test_default();
1084             dx.direction = CryptoDxDirection::Read;
1085             dx.epoch = epoch;
1086             dx
1087         };
1088         let app_read = |epoch| CryptoDxAppData {
1089             dx: read(epoch),
1090             cipher: TLS_AES_128_GCM_SHA256,
1091             next_secret: hkdf::import_key(TLS_VERSION_1_3, &[0xaa; 32]).unwrap(),
1092         };
1093         Self {
1094             initial: Some(CryptoState {
1095                 tx: CryptoDxState::test_default(),
1096                 rx: read(0),
1097             }),
1098             handshake: None,
1099             zero_rtt: None,
1100             cipher: TLS_AES_128_GCM_SHA256,
1101             // This isn't used, but the epoch is read to check for a key update.
1102             app_write: Some(app_read(3)),
1103             app_read: Some(app_read(3)),
1104             app_read_next: Some(app_read(4)),
1105             read_update_time: None,
1106         }
1107     }
1108 
1109     #[cfg(all(not(feature = "fuzzing"), test))]
test_chacha() -> Self1110     pub(crate) fn test_chacha() -> Self {
1111         const SECRET: &[u8] = &[
1112             0x9a, 0xc3, 0x12, 0xa7, 0xf8, 0x77, 0x46, 0x8e, 0xbe, 0x69, 0x42, 0x27, 0x48, 0xad,
1113             0x00, 0xa1, 0x54, 0x43, 0xf1, 0x82, 0x03, 0xa0, 0x7d, 0x60, 0x60, 0xf6, 0x88, 0xf3,
1114             0x0f, 0x21, 0x63, 0x2b,
1115         ];
1116         let secret = hkdf::import_key(TLS_VERSION_1_3, SECRET).unwrap();
1117         let app_read = |epoch| CryptoDxAppData {
1118             dx: CryptoDxState {
1119                 direction: CryptoDxDirection::Read,
1120                 epoch,
1121                 aead: Aead::new(
1122                     TLS_VERSION_1_3,
1123                     TLS_CHACHA20_POLY1305_SHA256,
1124                     &secret,
1125                     "quic ",
1126                 )
1127                 .unwrap(),
1128                 hpkey: HpKey::extract(
1129                     TLS_VERSION_1_3,
1130                     TLS_CHACHA20_POLY1305_SHA256,
1131                     &secret,
1132                     "quic hp",
1133                 )
1134                 .unwrap(),
1135                 used_pn: 0..645_971_972,
1136                 min_pn: 0,
1137                 invocations: 10,
1138             },
1139             cipher: TLS_CHACHA20_POLY1305_SHA256,
1140             next_secret: secret.clone(),
1141         };
1142         Self {
1143             initial: None,
1144             handshake: None,
1145             zero_rtt: None,
1146             cipher: TLS_CHACHA20_POLY1305_SHA256,
1147             app_write: Some(app_read(3)),
1148             app_read: Some(app_read(3)),
1149             app_read_next: Some(app_read(4)),
1150             read_update_time: None,
1151         }
1152     }
1153 }
1154 
1155 impl std::fmt::Display for CryptoStates {
fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result1156     fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
1157         write!(f, "CryptoStates")
1158     }
1159 }
1160 
1161 #[derive(Debug, Default)]
1162 pub struct CryptoStream {
1163     tx: TxBuffer,
1164     rx: RxStreamOrderer,
1165 }
1166 
1167 #[derive(Debug)]
1168 #[allow(dead_code)] // Suppress false positive: https://github.com/rust-lang/rust/issues/68408
1169 pub enum CryptoStreams {
1170     Initial {
1171         initial: CryptoStream,
1172         handshake: CryptoStream,
1173         application: CryptoStream,
1174     },
1175     Handshake {
1176         handshake: CryptoStream,
1177         application: CryptoStream,
1178     },
1179     ApplicationData {
1180         application: CryptoStream,
1181     },
1182 }
1183 
1184 impl CryptoStreams {
discard(&mut self, space: PacketNumberSpace)1185     pub fn discard(&mut self, space: PacketNumberSpace) {
1186         match space {
1187             PacketNumberSpace::Initial => {
1188                 if let Self::Initial {
1189                     handshake,
1190                     application,
1191                     ..
1192                 } = self
1193                 {
1194                     *self = Self::Handshake {
1195                         handshake: mem::take(handshake),
1196                         application: mem::take(application),
1197                     };
1198                 }
1199             }
1200             PacketNumberSpace::Handshake => {
1201                 if let Self::Handshake { application, .. } = self {
1202                     *self = Self::ApplicationData {
1203                         application: mem::take(application),
1204                     };
1205                 } else if matches!(self, Self::Initial { .. }) {
1206                     panic!("Discarding handshake before initial discarded");
1207                 }
1208             }
1209             PacketNumberSpace::ApplicationData => {
1210                 panic!("Discarding application data crypto streams")
1211             }
1212         }
1213     }
1214 
send(&mut self, space: PacketNumberSpace, data: &[u8])1215     pub fn send(&mut self, space: PacketNumberSpace, data: &[u8]) {
1216         self.get_mut(space).unwrap().tx.send(data);
1217     }
1218 
inbound_frame(&mut self, space: PacketNumberSpace, offset: u64, data: &[u8])1219     pub fn inbound_frame(&mut self, space: PacketNumberSpace, offset: u64, data: &[u8]) {
1220         self.get_mut(space).unwrap().rx.inbound_frame(offset, data);
1221     }
1222 
data_ready(&self, space: PacketNumberSpace) -> bool1223     pub fn data_ready(&self, space: PacketNumberSpace) -> bool {
1224         self.get(space).map_or(false, |cs| cs.rx.data_ready())
1225     }
1226 
read_to_end(&mut self, space: PacketNumberSpace, buf: &mut Vec<u8>) -> usize1227     pub fn read_to_end(&mut self, space: PacketNumberSpace, buf: &mut Vec<u8>) -> usize {
1228         self.get_mut(space).unwrap().rx.read_to_end(buf)
1229     }
1230 
acked(&mut self, token: &CryptoRecoveryToken)1231     pub fn acked(&mut self, token: &CryptoRecoveryToken) {
1232         self.get_mut(token.space)
1233             .unwrap()
1234             .tx
1235             .mark_as_acked(token.offset, token.length);
1236     }
1237 
lost(&mut self, token: &CryptoRecoveryToken)1238     pub fn lost(&mut self, token: &CryptoRecoveryToken) {
1239         // See BZ 1624800, ignore lost packets in spaces we've dropped keys
1240         if let Some(cs) = self.get_mut(token.space) {
1241             cs.tx.mark_as_lost(token.offset, token.length);
1242         }
1243     }
1244 
1245     /// Resend any Initial or Handshake CRYPTO frames that might be outstanding.
1246     /// This can help speed up handshake times.
resend_unacked(&mut self, space: PacketNumberSpace)1247     pub fn resend_unacked(&mut self, space: PacketNumberSpace) {
1248         if space != PacketNumberSpace::ApplicationData {
1249             if let Some(cs) = self.get_mut(space) {
1250                 cs.tx.unmark_sent();
1251             }
1252         }
1253     }
1254 
get(&self, space: PacketNumberSpace) -> Option<&CryptoStream>1255     fn get(&self, space: PacketNumberSpace) -> Option<&CryptoStream> {
1256         let (initial, hs, app) = match self {
1257             Self::Initial {
1258                 initial,
1259                 handshake,
1260                 application,
1261             } => (Some(initial), Some(handshake), Some(application)),
1262             Self::Handshake {
1263                 handshake,
1264                 application,
1265             } => (None, Some(handshake), Some(application)),
1266             Self::ApplicationData { application } => (None, None, Some(application)),
1267         };
1268         match space {
1269             PacketNumberSpace::Initial => initial,
1270             PacketNumberSpace::Handshake => hs,
1271             PacketNumberSpace::ApplicationData => app,
1272         }
1273     }
1274 
get_mut(&mut self, space: PacketNumberSpace) -> Option<&mut CryptoStream>1275     fn get_mut(&mut self, space: PacketNumberSpace) -> Option<&mut CryptoStream> {
1276         let (initial, hs, app) = match self {
1277             Self::Initial {
1278                 initial,
1279                 handshake,
1280                 application,
1281             } => (Some(initial), Some(handshake), Some(application)),
1282             Self::Handshake {
1283                 handshake,
1284                 application,
1285             } => (None, Some(handshake), Some(application)),
1286             Self::ApplicationData { application } => (None, None, Some(application)),
1287         };
1288         match space {
1289             PacketNumberSpace::Initial => initial,
1290             PacketNumberSpace::Handshake => hs,
1291             PacketNumberSpace::ApplicationData => app,
1292         }
1293     }
1294 
write_frame( &mut self, space: PacketNumberSpace, builder: &mut PacketBuilder, tokens: &mut Vec<RecoveryToken>, stats: &mut FrameStats, ) -> Res<()>1295     pub fn write_frame(
1296         &mut self,
1297         space: PacketNumberSpace,
1298         builder: &mut PacketBuilder,
1299         tokens: &mut Vec<RecoveryToken>,
1300         stats: &mut FrameStats,
1301     ) -> Res<()> {
1302         let cs = self.get_mut(space).unwrap();
1303         if let Some((offset, data)) = cs.tx.next_bytes() {
1304             let mut header_len = 1 + Encoder::varint_len(offset) + 1;
1305 
1306             // Don't bother if there isn't room for the header and some data.
1307             if builder.remaining() < header_len + 1 {
1308                 return Ok(());
1309             }
1310             // Calculate length of data based on the minimum of:
1311             // - available data
1312             // - remaining space, less the header, which counts only one byte
1313             //   for the length at first to avoid underestimating length
1314             let length = min(data.len(), builder.remaining() - header_len);
1315             header_len += Encoder::varint_len(u64::try_from(length).unwrap()) - 1;
1316             let length = min(data.len(), builder.remaining() - header_len);
1317 
1318             builder.encode_varint(crate::frame::FRAME_TYPE_CRYPTO);
1319             builder.encode_varint(offset);
1320             builder.encode_vvec(&data[..length]);
1321             if builder.len() > builder.limit() {
1322                 return Err(Error::InternalError(15));
1323             }
1324 
1325             cs.tx.mark_as_sent(offset, length);
1326 
1327             qdebug!("CRYPTO for {} offset={}, len={}", space, offset, length);
1328             tokens.push(RecoveryToken::Crypto(CryptoRecoveryToken {
1329                 space,
1330                 offset,
1331                 length,
1332             }));
1333             stats.crypto += 1;
1334         }
1335         Ok(())
1336     }
1337 }
1338 
1339 impl Default for CryptoStreams {
default() -> Self1340     fn default() -> Self {
1341         Self::Initial {
1342             initial: CryptoStream::default(),
1343             handshake: CryptoStream::default(),
1344             application: CryptoStream::default(),
1345         }
1346     }
1347 }
1348 
1349 #[derive(Debug, Clone)]
1350 pub struct CryptoRecoveryToken {
1351     space: PacketNumberSpace,
1352     offset: u64,
1353     length: usize,
1354 }
1355