1 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4 // option. This file may not be copied, modified, or distributed
5 // except according to those terms.
6 
7 pub use crate::agentio::{as_c_void, Record, RecordList};
8 use crate::agentio::{AgentIo, METHODS};
9 use crate::assert_initialized;
10 use crate::auth::AuthenticationStatus;
11 pub use crate::cert::CertificateInfo;
12 use crate::constants::*;
13 use crate::err::{is_blocked, secstatus_to_res, Error, PRErrorCode, Res};
14 use crate::ext::{ExtensionHandler, ExtensionTracker};
15 use crate::p11;
16 use crate::prio;
17 use crate::replay::AntiReplay;
18 use crate::secrets::SecretHolder;
19 use crate::ssl::{self, PRBool};
20 use crate::time::TimeHolder;
21 
22 use neqo_common::{matches, qdebug, qinfo, qtrace, qwarn};
23 use std::cell::RefCell;
24 use std::convert::TryFrom;
25 use std::ffi::CString;
26 use std::mem::{self, MaybeUninit};
27 use std::ops::{Deref, DerefMut};
28 use std::os::raw::{c_uint, c_void};
29 use std::pin::Pin;
30 use std::ptr::{null, null_mut, NonNull};
31 use std::rc::Rc;
32 use std::time::Instant;
33 
34 #[derive(Clone, Debug, PartialEq)]
35 pub enum HandshakeState {
36     New,
37     InProgress,
38     AuthenticationPending,
39     Authenticated(PRErrorCode),
40     Complete(SecretAgentInfo),
41     Failed(Error),
42 }
43 
44 impl HandshakeState {
45     #[must_use]
is_connected(&self) -> bool46     pub fn is_connected(&self) -> bool {
47         matches!(self, Self::Complete(_))
48     }
49 
50     #[must_use]
is_final(&self) -> bool51     pub fn is_final(&self) -> bool {
52         matches!(self, Self::Complete(_) | Self::Failed(_))
53     }
54 }
55 
get_alpn(fd: *mut ssl::PRFileDesc, pre: bool) -> Res<Option<String>>56 fn get_alpn(fd: *mut ssl::PRFileDesc, pre: bool) -> Res<Option<String>> {
57     let mut alpn_state = ssl::SSLNextProtoState::SSL_NEXT_PROTO_NO_SUPPORT;
58     let mut chosen = vec![0_u8; 255];
59     let mut chosen_len: c_uint = 0;
60     secstatus_to_res(unsafe {
61         ssl::SSL_GetNextProto(
62             fd,
63             &mut alpn_state,
64             chosen.as_mut_ptr(),
65             &mut chosen_len,
66             c_uint::try_from(chosen.len())?,
67         )
68     })?;
69 
70     let alpn = match (pre, alpn_state) {
71         (true, ssl::SSLNextProtoState::SSL_NEXT_PROTO_EARLY_VALUE)
72         | (false, ssl::SSLNextProtoState::SSL_NEXT_PROTO_NEGOTIATED)
73         | (false, ssl::SSLNextProtoState::SSL_NEXT_PROTO_SELECTED) => {
74             chosen.truncate(chosen_len as usize);
75             Some(match String::from_utf8(chosen) {
76                 Ok(a) => a,
77                 _ => return Err(Error::InternalError),
78             })
79         }
80         _ => None,
81     };
82     qtrace!([format!("{:p}", fd)], "got ALPN {:?}", alpn);
83     Ok(alpn)
84 }
85 
/// Preliminary information about a connection, collected by `SecretAgent::preinfo`
/// and available before the handshake completes.
pub struct SecretAgentPreInfo {
    /// The raw preliminary channel information reported by NSS.
    info: ssl::SSLPreliminaryChannelInfo,
    /// The early (0-RTT) ALPN value, if NSS reports one.
    alpn: Option<String>,
}
90 
// Generates a getter `$v` that yields `Some(self.info.$f as $t)` when NSS has set the
// `ssl::$m` bit in `self.info.valuesSet`, and `None` when that field is not valid yet.
macro_rules! preinfo_arg {
    ($v:ident, $m:ident, $f:ident: $t:ident $(,)?) => {
        #[must_use]
        pub fn $v(&self) -> Option<$t> {
            if (self.info.valuesSet & ssl::$m) == 0 {
                None
            } else {
                Some(self.info.$f as $t)
            }
        }
    };
}
102 
103 impl SecretAgentPreInfo {
new(fd: *mut ssl::PRFileDesc) -> Res<Self>104     fn new(fd: *mut ssl::PRFileDesc) -> Res<Self> {
105         let mut info: MaybeUninit<ssl::SSLPreliminaryChannelInfo> = MaybeUninit::uninit();
106         secstatus_to_res(unsafe {
107             ssl::SSL_GetPreliminaryChannelInfo(
108                 fd,
109                 info.as_mut_ptr(),
110                 c_uint::try_from(mem::size_of::<ssl::SSLPreliminaryChannelInfo>())?,
111             )
112         })?;
113 
114         Ok(Self {
115             info: unsafe { info.assume_init() },
116             alpn: get_alpn(fd, true)?,
117         })
118     }
119 
120     preinfo_arg!(version, ssl_preinfo_version, protocolVersion: Version);
121     preinfo_arg!(cipher_suite, ssl_preinfo_cipher_suite, cipherSuite: Cipher);
122     #[must_use]
early_data(&self) -> bool123     pub fn early_data(&self) -> bool {
124         self.info.canSendEarlyData != 0
125     }
126     #[must_use]
max_early_data(&self) -> usize127     pub fn max_early_data(&self) -> usize {
128         usize::try_from(self.info.maxEarlyDataSize).unwrap()
129     }
130     #[must_use]
alpn(&self) -> Option<&String>131     pub fn alpn(&self) -> Option<&String> {
132         self.alpn.as_ref()
133     }
134 
135     preinfo_arg!(
136         early_data_cipher,
137         ssl_preinfo_0rtt_cipher_suite,
138         zeroRttCipherSuite: Cipher,
139     );
140 }
141 
/// Information about a completed handshake, captured from NSS when the handshake finishes.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct SecretAgentInfo {
    /// Negotiated protocol version.
    version: Version,
    /// Negotiated cipher suite.
    cipher: Cipher,
    /// Key exchange group that was used.
    group: Group,
    /// Whether the session was resumed.
    resumed: bool,
    /// Whether early data was accepted.
    early_data: bool,
    /// The negotiated or selected ALPN value, if any.
    alpn: Option<String>,
    /// Signature scheme reported by NSS for the handshake.
    signature_scheme: SignatureScheme,
}
152 
153 impl SecretAgentInfo {
new(fd: *mut ssl::PRFileDesc) -> Res<Self>154     fn new(fd: *mut ssl::PRFileDesc) -> Res<Self> {
155         let mut info: MaybeUninit<ssl::SSLChannelInfo> = MaybeUninit::uninit();
156         secstatus_to_res(unsafe {
157             ssl::SSL_GetChannelInfo(
158                 fd,
159                 info.as_mut_ptr(),
160                 c_uint::try_from(mem::size_of::<ssl::SSLChannelInfo>())?,
161             )
162         })?;
163         let info = unsafe { info.assume_init() };
164         Ok(Self {
165             version: info.protocolVersion as Version,
166             cipher: info.cipherSuite as Cipher,
167             group: Group::try_from(info.keaGroup)?,
168             resumed: info.resumed != 0,
169             early_data: info.earlyDataAccepted != 0,
170             alpn: get_alpn(fd, false)?,
171             signature_scheme: SignatureScheme::try_from(info.signatureScheme)?,
172         })
173     }
174     #[must_use]
version(&self) -> Version175     pub fn version(&self) -> Version {
176         self.version
177     }
178     #[must_use]
cipher_suite(&self) -> Cipher179     pub fn cipher_suite(&self) -> Cipher {
180         self.cipher
181     }
182     #[must_use]
key_exchange(&self) -> Group183     pub fn key_exchange(&self) -> Group {
184         self.group
185     }
186     #[must_use]
resumed(&self) -> bool187     pub fn resumed(&self) -> bool {
188         self.resumed
189     }
190     #[must_use]
early_data_accepted(&self) -> bool191     pub fn early_data_accepted(&self) -> bool {
192         self.early_data
193     }
194     #[must_use]
alpn(&self) -> Option<&String>195     pub fn alpn(&self) -> Option<&String> {
196         self.alpn.as_ref()
197     }
198     #[must_use]
signature_scheme(&self) -> SignatureScheme199     pub fn signature_scheme(&self) -> SignatureScheme {
200         self.signature_scheme
201     }
202 }
203 
/// `SecretAgent` holds the common parts of client and server.
#[derive(Debug)]
#[allow(clippy::module_name_repetitions)]
pub struct SecretAgent {
    /// The NSS SSL file descriptor; reset to null once `close` has run.
    fd: *mut ssl::PRFileDesc,
    /// Traffic secrets; registered with `fd` on the first handshake call (`set_raw`).
    secrets: SecretHolder,
    /// `None` until a handshake method is first used, then `Some(true)` for
    /// `handshake_raw` or `Some(false)` for `handshake`.
    raw: Option<bool>,
    /// I/O layer.  Pinned because its address is stored in the descriptor (`create_fd`).
    io: Pin<Box<AgentIo>>,
    /// Current handshake state.
    state: HandshakeState,

    /// Records whether authentication of certificates is required.
    /// Pinned because its address is handed to NSS in `ready`.
    auth_required: Pin<Box<bool>>,
    /// Records any fatal alert that is sent by the stack.
    /// Pinned because its address is handed to NSS in `ready`.
    alert: Pin<Box<Option<Alert>>>,
    /// The current time.
    now: TimeHolder,

    /// Trackers for installed extension handlers, stored so they live as long as the agent.
    extension_handlers: Vec<ExtensionTracker>,
    /// NOTE(review): only initialized in `new` in the code visible here; confirm it is used.
    inf: Option<SecretAgentInfo>,

    /// Whether or not EndOfEarlyData should be suppressed.
    no_eoed: bool,
}
227 
228 impl SecretAgent {
new() -> Res<Self>229     fn new() -> Res<Self> {
230         let mut io = Box::pin(AgentIo::new());
231         let fd = Self::create_fd(&mut io)?;
232         Ok(Self {
233             fd,
234             secrets: SecretHolder::default(),
235             raw: None,
236             io,
237             state: HandshakeState::New,
238 
239             auth_required: Box::pin(false),
240             alert: Box::pin(None),
241             now: TimeHolder::default(),
242 
243             extension_handlers: Vec::new(),
244             inf: None,
245 
246             no_eoed: false,
247         })
248     }
249 
250     // Create a new SSL file descriptor.
251     //
252     // Note that we create separate bindings for PRFileDesc as both
253     // ssl::PRFileDesc and prio::PRFileDesc.  This keeps the bindings
254     // minimal, but it means that the two forms need casts to translate
255     // between them.  ssl::PRFileDesc is left as an opaque type, as the
256     // ssl::SSL_* APIs only need an opaque type.
create_fd(io: &mut Pin<Box<AgentIo>>) -> Res<*mut ssl::PRFileDesc>257     fn create_fd(io: &mut Pin<Box<AgentIo>>) -> Res<*mut ssl::PRFileDesc> {
258         assert_initialized();
259         let label = CString::new("sslwrapper")?;
260         let id = unsafe { prio::PR_GetUniqueIdentity(label.as_ptr()) };
261 
262         let base_fd = unsafe { prio::PR_CreateIOLayerStub(id, METHODS) };
263         if base_fd.is_null() {
264             return Err(Error::CreateSslSocket);
265         }
266         let fd = unsafe {
267             (*base_fd).secret = as_c_void(io) as *mut _;
268             ssl::SSL_ImportFD(null_mut(), base_fd as *mut ssl::PRFileDesc)
269         };
270         if fd.is_null() {
271             unsafe { prio::PR_Close(base_fd) };
272             return Err(Error::CreateSslSocket);
273         }
274         Ok(fd)
275     }
276 
auth_complete_hook( arg: *mut c_void, _fd: *mut ssl::PRFileDesc, _check_sig: ssl::PRBool, _is_server: ssl::PRBool, ) -> ssl::SECStatus277     unsafe extern "C" fn auth_complete_hook(
278         arg: *mut c_void,
279         _fd: *mut ssl::PRFileDesc,
280         _check_sig: ssl::PRBool,
281         _is_server: ssl::PRBool,
282     ) -> ssl::SECStatus {
283         let auth_required_ptr = arg as *mut bool;
284         *auth_required_ptr = true;
285         // NSS insists on getting SECWouldBlock here rather than accepting
286         // the usual combination of PR_WOULD_BLOCK_ERROR and SECFailure.
287         ssl::_SECStatus_SECWouldBlock
288     }
289 
alert_sent_cb( fd: *const ssl::PRFileDesc, arg: *mut c_void, alert: *const ssl::SSLAlert, )290     unsafe extern "C" fn alert_sent_cb(
291         fd: *const ssl::PRFileDesc,
292         arg: *mut c_void,
293         alert: *const ssl::SSLAlert,
294     ) {
295         let alert = alert.as_ref().unwrap();
296         if alert.level == 2 {
297             // Fatal alerts demand attention.
298             let p = arg as *mut Option<Alert>;
299             let st = p.as_mut().unwrap();
300             if st.is_none() {
301                 *st = Some(alert.description);
302             } else {
303                 qwarn!(
304                     [format!("{:p}", fd)],
305                     "duplicate alert {}",
306                     alert.description
307                 );
308             }
309         }
310     }
311 
312     // Ready this for connecting.
ready(&mut self, is_server: bool) -> Res<()>313     fn ready(&mut self, is_server: bool) -> Res<()> {
314         secstatus_to_res(unsafe {
315             ssl::SSL_AuthCertificateHook(
316                 self.fd,
317                 Some(Self::auth_complete_hook),
318                 as_c_void(&mut self.auth_required),
319             )
320         })?;
321 
322         secstatus_to_res(unsafe {
323             ssl::SSL_AlertSentCallback(
324                 self.fd,
325                 Some(Self::alert_sent_cb),
326                 as_c_void(&mut self.alert),
327             )
328         })?;
329 
330         self.now.bind(self.fd)?;
331         self.configure()?;
332         secstatus_to_res(unsafe { ssl::SSL_ResetHandshake(self.fd, is_server as ssl::PRBool) })
333     }
334 
335     /// Default configuration.
336     ///
337     /// # Errors
338     /// If `set_version_range` fails.
configure(&mut self) -> Res<()>339     fn configure(&mut self) -> Res<()> {
340         self.set_version_range(TLS_VERSION_1_3, TLS_VERSION_1_3)?;
341         self.set_option(ssl::Opt::Locking, false)?;
342         self.set_option(ssl::Opt::Tickets, false)?;
343         self.set_option(ssl::Opt::OcspStapling, true)?;
344         Ok(())
345     }
346 
347     /// Set the versions that are supported.
348     ///
349     /// # Errors
350     /// If the range of versions isn't supported.
set_version_range(&mut self, min: Version, max: Version) -> Res<()>351     pub fn set_version_range(&mut self, min: Version, max: Version) -> Res<()> {
352         let range = ssl::SSLVersionRange {
353             min: min as ssl::PRUint16,
354             max: max as ssl::PRUint16,
355         };
356         secstatus_to_res(unsafe { ssl::SSL_VersionRangeSet(self.fd, &range) })
357     }
358 
359     /// Enable a set of ciphers.  Note that the order of these is not respected.
360     ///
361     /// # Errors
362     /// If NSS can't enable or disable ciphers.
enable_ciphers(&mut self, ciphers: &[Cipher]) -> Res<()>363     pub fn enable_ciphers(&mut self, ciphers: &[Cipher]) -> Res<()> {
364         let all_ciphers = unsafe { ssl::SSL_GetImplementedCiphers() };
365         let cipher_count = unsafe { ssl::SSL_GetNumImplementedCiphers() } as usize;
366         for i in 0..cipher_count {
367             let p = all_ciphers.wrapping_add(i);
368             secstatus_to_res(unsafe {
369                 ssl::SSL_CipherPrefSet(self.fd, i32::from(*p), false as ssl::PRBool)
370             })?;
371         }
372 
373         for c in ciphers {
374             secstatus_to_res(unsafe {
375                 ssl::SSL_CipherPrefSet(self.fd, i32::from(*c), true as ssl::PRBool)
376             })?;
377         }
378         Ok(())
379     }
380 
381     /// Set key exchange groups.
382     ///
383     /// # Errors
384     /// If the underlying API fails (which shouldn't happen).
set_groups(&mut self, groups: &[Group]) -> Res<()>385     pub fn set_groups(&mut self, groups: &[Group]) -> Res<()> {
386         // SSLNamedGroup is a different size to Group, so copy one by one.
387         let group_vec: Vec<_> = groups
388             .iter()
389             .map(|&g| ssl::SSLNamedGroup::Type::from(g))
390             .collect();
391 
392         let ptr = group_vec.as_slice().as_ptr();
393         secstatus_to_res(unsafe {
394             ssl::SSL_NamedGroupConfig(self.fd, ptr, c_uint::try_from(group_vec.len())?)
395         })
396     }
397 
398     /// Set TLS options.
399     ///
400     /// # Errors
401     /// Returns an error if the option or option value is invalid; i.e., never.
set_option(&mut self, opt: ssl::Opt, value: bool) -> Res<()>402     pub fn set_option(&mut self, opt: ssl::Opt, value: bool) -> Res<()> {
403         secstatus_to_res(unsafe {
404             ssl::SSL_OptionSet(self.fd, opt.as_int(), opt.map_enabled(value))
405         })
406     }
407 
408     /// Enable 0-RTT.
409     ///
410     /// # Errors
411     /// See `set_option`.
enable_0rtt(&mut self) -> Res<()>412     pub fn enable_0rtt(&mut self) -> Res<()> {
413         self.set_option(ssl::Opt::EarlyData, true)
414     }
415 
416     /// Disable the `EndOfEarlyData` message.
disable_end_of_early_data(&mut self)417     pub fn disable_end_of_early_data(&mut self) {
418         self.no_eoed = true;
419     }
420 
421     /// `set_alpn` sets a list of preferred protocols, starting with the most preferred.
422     /// Though ALPN [RFC7301] permits octet sequences, this only allows for UTF-8-encoded
423     /// strings.
424     ///
425     /// This asserts if no items are provided, or if any individual item is longer than
426     /// 255 octets in length.
427     ///
428     /// # Errors
429     /// This should always panic rather than return an error.
set_alpn(&mut self, protocols: &[impl AsRef<str>]) -> Res<()>430     pub fn set_alpn(&mut self, protocols: &[impl AsRef<str>]) -> Res<()> {
431         // Validate and set length.
432         let mut encoded_len = protocols.len();
433         for v in protocols {
434             assert!(v.as_ref().len() < 256);
435             encoded_len += v.as_ref().len();
436         }
437 
438         // Prepare to encode.
439         let mut encoded = Vec::with_capacity(encoded_len);
440         let mut add = |v: &str| {
441             if let Ok(s) = u8::try_from(v.len()) {
442                 encoded.push(s);
443                 encoded.extend_from_slice(v.as_bytes());
444             }
445         };
446 
447         // NSS inherited an idiosyncratic API as a result of having implemented NPN
448         // before ALPN.  For that reason, we need to put the "best" option last.
449         let (first, rest) = protocols
450             .split_first()
451             .expect("at least one ALPN value needed");
452         for v in rest {
453             add(v.as_ref());
454         }
455         add(first.as_ref());
456         assert_eq!(encoded_len, encoded.len());
457 
458         // Now give the result to NSS.
459         secstatus_to_res(unsafe {
460             ssl::SSL_SetNextProtoNego(
461                 self.fd,
462                 encoded.as_slice().as_ptr(),
463                 c_uint::try_from(encoded.len())?,
464             )
465         })
466     }
467 
468     /// Install an extension handler.
469     ///
470     /// This can be called multiple times with different values for `ext`.  The handler is provided as
471     /// Rc<RefCell<>> so that the caller is able to hold a reference to the handler and later access any
472     /// state that it accumulates.
473     ///
474     /// # Errors
475     /// When the extension handler can't be successfully installed.
extension_handler( &mut self, ext: Extension, handler: Rc<RefCell<dyn ExtensionHandler>>, ) -> Res<()>476     pub fn extension_handler(
477         &mut self,
478         ext: Extension,
479         handler: Rc<RefCell<dyn ExtensionHandler>>,
480     ) -> Res<()> {
481         let tracker = unsafe { ExtensionTracker::new(self.fd, ext, handler) }?;
482         self.extension_handlers.push(tracker);
483         Ok(())
484     }
485 
486     // This function tracks whether handshake() or handshake_raw() was used
487     // and prevents the other from being used.
set_raw(&mut self, r: bool) -> Res<()>488     fn set_raw(&mut self, r: bool) -> Res<()> {
489         if self.raw.is_none() {
490             self.secrets.register(self.fd)?;
491             self.raw = Some(r);
492             Ok(())
493         } else if self.raw.unwrap() == r {
494             Ok(())
495         } else {
496             Err(Error::MixedHandshakeMethod)
497         }
498     }
499 
500     /// Get information about the connection.
501     /// This includes the version, ciphersuite, and ALPN.
502     ///
503     /// Calling this function returns None until the connection is complete.
504     #[must_use]
info(&self) -> Option<&SecretAgentInfo>505     pub fn info(&self) -> Option<&SecretAgentInfo> {
506         match self.state {
507             HandshakeState::Complete(ref info) => Some(info),
508             _ => None,
509         }
510     }
511 
512     /// Get any preliminary information about the status of the connection.
513     ///
514     /// This includes whether 0-RTT was accepted and any information related to that.
515     /// Calling this function collects all the relevant information.
516     ///
517     /// # Errors
518     /// When the underlying socket functions fail.
preinfo(&self) -> Res<SecretAgentPreInfo>519     pub fn preinfo(&self) -> Res<SecretAgentPreInfo> {
520         SecretAgentPreInfo::new(self.fd)
521     }
522 
523     /// Get the peer's certificate chain.
524     #[must_use]
peer_certificate(&self) -> Option<CertificateInfo>525     pub fn peer_certificate(&self) -> Option<CertificateInfo> {
526         CertificateInfo::new(self.fd)
527     }
528 
529     /// Return any fatal alert that the TLS stack might have sent.
530     #[must_use]
alert(&self) -> Option<&Alert>531     pub fn alert(&self) -> Option<&Alert> {
532         (&*self.alert).as_ref()
533     }
534 
535     /// Call this function to mark the peer as authenticated.
536     /// Only call this function if `handshake/handshake_raw` returns
537     /// `HandshakeState::AuthenticationPending`, or it will panic.
authenticated(&mut self, status: AuthenticationStatus)538     pub fn authenticated(&mut self, status: AuthenticationStatus) {
539         assert_eq!(self.state, HandshakeState::AuthenticationPending);
540         *self.auth_required = false;
541         self.state = HandshakeState::Authenticated(status.into());
542     }
543 
capture_error<T>(&mut self, res: Res<T>) -> Res<T>544     fn capture_error<T>(&mut self, res: Res<T>) -> Res<T> {
545         if let Err(e) = &res {
546             qwarn!([self], "error: {:?}", e);
547             self.state = HandshakeState::Failed(e.clone());
548         }
549         res
550     }
551 
update_state(&mut self, res: Res<()>) -> Res<()>552     fn update_state(&mut self, res: Res<()>) -> Res<()> {
553         self.state = if is_blocked(&res) {
554             if *self.auth_required {
555                 HandshakeState::AuthenticationPending
556             } else {
557                 HandshakeState::InProgress
558             }
559         } else {
560             self.capture_error(res)?;
561             let info = self.capture_error(SecretAgentInfo::new(self.fd))?;
562             HandshakeState::Complete(info)
563         };
564         qinfo!([self], "state -> {:?}", self.state);
565         Ok(())
566     }
567 
568     /// Drive the TLS handshake, taking bytes from `input` and putting
569     /// any bytes necessary into `output`.
570     /// This takes the current time as `now`.
571     /// On success a tuple of a `HandshakeState` and usize indicate whether the handshake
572     /// is complete and how many bytes were written to `output`, respectively.
573     /// If the state is `HandshakeState::AuthenticationPending`, then ONLY call this
574     /// function if you want to proceed, because this will mark the certificate as OK.
575     ///
576     /// # Errors
577     /// When the handshake fails this returns an error.
handshake(&mut self, now: Instant, input: &[u8]) -> Res<Vec<u8>>578     pub fn handshake(&mut self, now: Instant, input: &[u8]) -> Res<Vec<u8>> {
579         self.now.set(now)?;
580         self.set_raw(false)?;
581 
582         let rv = {
583             // Within this scope, _h maintains a mutable reference to self.io.
584             let _h = self.io.wrap(input);
585             match self.state {
586                 HandshakeState::Authenticated(ref err) => unsafe {
587                     ssl::SSL_AuthCertificateComplete(self.fd, *err)
588                 },
589                 _ => unsafe { ssl::SSL_ForceHandshake(self.fd) },
590             }
591         };
592         // Take before updating state so that we leave the output buffer empty
593         // even if there is an error.
594         let output = self.io.take_output();
595         self.update_state(secstatus_to_res(rv))?;
596         Ok(output)
597     }
598 
599     /// Setup to receive records for raw handshake functions.
setup_raw(&mut self) -> Res<Pin<Box<RecordList>>>600     fn setup_raw(&mut self) -> Res<Pin<Box<RecordList>>> {
601         self.set_raw(true)?;
602         self.capture_error(RecordList::setup(self.fd))
603     }
604 
inject_eoed(&mut self) -> Res<()>605     fn inject_eoed(&mut self) -> Res<()> {
606         // EndOfEarlyData is as follows:
607         // struct {
608         //    HandshakeType msg_type = end_of_early_data(5);
609         //    uint24 length = 0;
610         // };
611         const END_OF_EARLY_DATA: &[u8] = &[5, 0, 0, 0];
612 
613         if self.no_eoed {
614             let mut read_epoch: u16 = 0;
615             unsafe { ssl::SSL_GetCurrentEpoch(self.fd, &mut read_epoch, null_mut()) }?;
616             if read_epoch == 1 {
617                 // It's waiting for EndOfEarlyData, so feed one in.
618                 // Note that this is the test that ensures that we only do this for the server.
619                 let eoed = Record::new(1, 22, END_OF_EARLY_DATA);
620                 self.capture_error(eoed.write(self.fd))?;
621                 self.no_eoed = false;
622             }
623         }
624         Ok(())
625     }
626 
627     /// Drive the TLS handshake, but get the raw content of records, not
628     /// protected records as bytes. This function is incompatible with
629     /// `handshake()`; use either this or `handshake()` exclusively.
630     ///
631     /// Ideally, this only includes records from the current epoch.
632     /// If you send data from multiple epochs, you might end up being sad.
633     ///
634     /// # Errors
635     /// When the handshake fails this returns an error.
handshake_raw(&mut self, now: Instant, input: Option<Record>) -> Res<RecordList>636     pub fn handshake_raw(&mut self, now: Instant, input: Option<Record>) -> Res<RecordList> {
637         self.now.set(now)?;
638         let mut records = self.setup_raw()?;
639 
640         // Fire off any authentication we might need to complete.
641         if let HandshakeState::Authenticated(ref err) = self.state {
642             let result =
643                 secstatus_to_res(unsafe { ssl::SSL_AuthCertificateComplete(self.fd, *err) });
644             qdebug!([self], "SSL_AuthCertificateComplete: {:?}", result);
645             // This should return SECSuccess, so don't use update_state().
646             self.capture_error(result)?;
647         }
648 
649         // Feed in any records.
650         if let Some(rec) = input {
651             if rec.epoch == 2 {
652                 self.inject_eoed()?;
653             }
654             self.capture_error(rec.write(self.fd))?;
655         }
656 
657         // Drive the handshake once more.
658         let rv = secstatus_to_res(unsafe { ssl::SSL_ForceHandshake(self.fd) });
659         self.update_state(rv)?;
660 
661         if self.no_eoed {
662             records.remove_eoed();
663         }
664 
665         Ok(*Pin::into_inner(records))
666     }
667 
close(&mut self)668     pub fn close(&mut self) {
669         // It should be safe to close multiple times.
670         if self.fd.is_null() {
671             return;
672         }
673         if let Some(true) = self.raw {
674             // Need to hold the record list in scope until the close is done.
675             let _records = self.setup_raw().expect("Can only close");
676             unsafe { prio::PR_Close(self.fd as *mut prio::PRFileDesc) };
677         } else {
678             // Need to hold the IO wrapper in scope until the close is done.
679             let _io = self.io.wrap(&[]);
680             unsafe { prio::PR_Close(self.fd as *mut prio::PRFileDesc) };
681         };
682         let _output = self.io.take_output();
683         self.fd = null_mut();
684     }
685 
686     /// State returns the status of the handshake.
687     #[must_use]
state(&self) -> &HandshakeState688     pub fn state(&self) -> &HandshakeState {
689         &self.state
690     }
691     /// Take a read secret.  This will only return a non-`None` value once.
692     #[must_use]
read_secret(&mut self, epoch: Epoch) -> Option<p11::SymKey>693     pub fn read_secret(&mut self, epoch: Epoch) -> Option<p11::SymKey> {
694         self.secrets.take_read(epoch)
695     }
696     /// Take a write secret.
697     #[must_use]
write_secret(&mut self, epoch: Epoch) -> Option<p11::SymKey>698     pub fn write_secret(&mut self, epoch: Epoch) -> Option<p11::SymKey> {
699         self.secrets.take_write(epoch)
700     }
701 }
702 
703 impl Drop for SecretAgent {
drop(&mut self)704     fn drop(&mut self) {
705         self.close();
706     }
707 }
708 
709 impl ::std::fmt::Display for SecretAgent {
fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result710     fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
711         write!(f, "Agent {:p}", self.fd)
712     }
713 }
714 
/// A TLS Client.
#[derive(Debug)]
pub struct Client {
    /// The shared client/server machinery.
    agent: SecretAgent,

    /// Records the last resumption token.
    /// Pinned because its address is handed to NSS in `Client::ready`.
    resumption: Pin<Box<Option<Vec<u8>>>>,
}
723 
724 impl Client {
725     /// Create a new client agent.
726     ///
727     /// # Errors
728     /// Errors returned if the socket can't be created or configured.
new(server_name: &str) -> Res<Self>729     pub fn new(server_name: &str) -> Res<Self> {
730         let mut agent = SecretAgent::new()?;
731         let url = CString::new(server_name)?;
732         secstatus_to_res(unsafe { ssl::SSL_SetURL(agent.fd, url.as_ptr()) })?;
733         agent.ready(false)?;
734         let mut client = Self {
735             agent,
736             resumption: Box::pin(None),
737         };
738         client.ready()?;
739         Ok(client)
740     }
741 
resumption_token_cb( fd: *mut ssl::PRFileDesc, token: *const u8, len: c_uint, arg: *mut c_void, ) -> ssl::SECStatus742     unsafe extern "C" fn resumption_token_cb(
743         fd: *mut ssl::PRFileDesc,
744         token: *const u8,
745         len: c_uint,
746         arg: *mut c_void,
747     ) -> ssl::SECStatus {
748         let resumption_ptr = arg as *mut Option<Vec<u8>>;
749         let resumption = resumption_ptr.as_mut().unwrap();
750         let mut v = Vec::with_capacity(len as usize);
751         v.extend_from_slice(std::slice::from_raw_parts(token, len as usize));
752         qdebug!([format!("{:p}", fd)], "Got resumption token");
753         *resumption = Some(v);
754         ssl::SECSuccess
755     }
756 
ready(&mut self) -> Res<()>757     fn ready(&mut self) -> Res<()> {
758         let fd = self.fd;
759         unsafe {
760             ssl::SSL_SetResumptionTokenCallback(
761                 fd,
762                 Some(Self::resumption_token_cb),
763                 as_c_void(&mut self.resumption),
764             )
765         }
766     }
767 
768     /// Return the resumption token.
769     #[must_use]
resumption_token(&self) -> Option<&Vec<u8>>770     pub fn resumption_token(&self) -> Option<&Vec<u8>> {
771         (*self.resumption).as_ref()
772     }
773 
774     /// Enable resumption, using a token previously provided.
775     ///
776     /// # Errors
777     /// Error returned when the resumption token is invalid or
778     /// the socket is not able to use the value.
set_resumption_token(&mut self, token: &[u8]) -> Res<()>779     pub fn set_resumption_token(&mut self, token: &[u8]) -> Res<()> {
780         unsafe {
781             ssl::SSL_SetResumptionToken(
782                 self.agent.fd,
783                 token.as_ptr(),
784                 c_uint::try_from(token.len())?,
785             )
786         }
787     }
788 }
789 
790 impl Deref for Client {
791     type Target = SecretAgent;
792     #[must_use]
deref(&self) -> &SecretAgent793     fn deref(&self) -> &SecretAgent {
794         &self.agent
795     }
796 }
797 
798 impl DerefMut for Client {
deref_mut(&mut self) -> &mut SecretAgent799     fn deref_mut(&mut self) -> &mut SecretAgent {
800         &mut self.agent
801     }
802 }
803 
/// `ZeroRttCheckResult` encapsulates the options for handling a `ClientHello`.
/// It is the value returned by `ZeroRttChecker::check`.
#[derive(Clone, Debug, PartialEq)]
pub enum ZeroRttCheckResult {
    /// Accept 0-RTT; the default.
    Accept,
    /// Reject 0-RTT, but continue the handshake normally.
    Reject,
    /// Send HelloRetryRequest (probably not needed for QUIC).
    /// NOTE(review): how the payload is used is decided by the server-side
    /// callback, which is not visible here.
    HelloRetryRequest(Vec<u8>),
    /// Fail the handshake.
    Fail,
}
816 
/// A `ZeroRttChecker` is used by the agent to validate the application token (as provided by `send_ticket`).
///
/// The `Unpin` bound lets a boxed checker be pinned with `Pin::new`
/// (as `ZeroRttCheckState::new` does).
pub trait ZeroRttChecker: std::fmt::Debug + std::marker::Unpin {
    /// Decide how to handle 0-RTT, given the application `token`.
    fn check(&self, token: &[u8]) -> ZeroRttCheckResult;
}
821 
822 #[derive(Debug)]
823 struct ZeroRttCheckState {
824     fd: *mut ssl::PRFileDesc,
825     checker: Pin<Box<dyn ZeroRttChecker>>,
826 }
827 
828 impl ZeroRttCheckState {
new(fd: *mut ssl::PRFileDesc, checker: Box<dyn ZeroRttChecker>) -> Self829     pub fn new(fd: *mut ssl::PRFileDesc, checker: Box<dyn ZeroRttChecker>) -> Self {
830         Self {
831             fd,
832             checker: Pin::new(checker),
833         }
834     }
835 }
836 
837 #[derive(Debug)]
838 pub struct Server {
839     agent: SecretAgent,
840     /// This holds the HRR callback context.
841     zero_rtt_check: Option<Pin<Box<ZeroRttCheckState>>>,
842 }
843 
844 impl Server {
845     /// Create a new server agent.
846     ///
847     /// # Errors
848     /// Errors returned when NSS fails.
new(certificates: &[impl AsRef<str>]) -> Res<Self>849     pub fn new(certificates: &[impl AsRef<str>]) -> Res<Self> {
850         let mut agent = SecretAgent::new()?;
851 
852         for n in certificates {
853             let c = CString::new(n.as_ref())?;
854             let cert = match NonNull::new(unsafe {
855                 p11::PK11_FindCertFromNickname(c.as_ptr(), null_mut())
856             }) {
857                 None => return Err(Error::CertificateLoading),
858                 Some(ptr) => p11::Certificate::new(ptr),
859             };
860             let key = match NonNull::new(unsafe {
861                 p11::PK11_FindKeyByAnyCert(*cert.deref(), null_mut())
862             }) {
863                 None => return Err(Error::CertificateLoading),
864                 Some(ptr) => p11::PrivateKey::new(ptr),
865             };
866             secstatus_to_res(unsafe {
867                 ssl::SSL_ConfigServerCert(agent.fd, *cert.deref(), *key.deref(), null(), 0)
868             })?;
869         }
870 
871         agent.ready(true)?;
872         Ok(Self {
873             agent,
874             zero_rtt_check: None,
875         })
876     }
877 
hello_retry_cb( first_hello: PRBool, client_token: *const u8, client_token_len: c_uint, retry_token: *mut u8, retry_token_len: *mut c_uint, retry_token_max: c_uint, arg: *mut c_void, ) -> ssl::SSLHelloRetryRequestAction::Type878     unsafe extern "C" fn hello_retry_cb(
879         first_hello: PRBool,
880         client_token: *const u8,
881         client_token_len: c_uint,
882         retry_token: *mut u8,
883         retry_token_len: *mut c_uint,
884         retry_token_max: c_uint,
885         arg: *mut c_void,
886     ) -> ssl::SSLHelloRetryRequestAction::Type {
887         if first_hello == 0 {
888             // On the second ClientHello after HelloRetryRequest, skip checks.
889             return ssl::SSLHelloRetryRequestAction::ssl_hello_retry_accept;
890         }
891 
892         let p = arg as *mut ZeroRttCheckState;
893         let check_state = p.as_mut().unwrap();
894         let token = if client_token.is_null() {
895             &[]
896         } else {
897             std::slice::from_raw_parts(client_token, client_token_len as usize)
898         };
899         match check_state.checker.check(token) {
900             ZeroRttCheckResult::Accept => ssl::SSLHelloRetryRequestAction::ssl_hello_retry_accept,
901             ZeroRttCheckResult::Fail => ssl::SSLHelloRetryRequestAction::ssl_hello_retry_fail,
902             ZeroRttCheckResult::Reject => {
903                 ssl::SSLHelloRetryRequestAction::ssl_hello_retry_reject_0rtt
904             }
905             ZeroRttCheckResult::HelloRetryRequest(tok) => {
906                 // Don't bother propagating errors from this, because it should be caught in testing.
907                 assert!(tok.len() <= usize::try_from(retry_token_max).unwrap());
908                 let slc = std::slice::from_raw_parts_mut(retry_token, tok.len());
909                 slc.copy_from_slice(&tok);
910                 *retry_token_len = c_uint::try_from(tok.len()).expect("token was way too big");
911                 ssl::SSLHelloRetryRequestAction::ssl_hello_retry_request
912             }
913         }
914     }
915 
916     /// Enable 0-RTT.  This shadows the function of the same name that can be accessed
917     /// via the Deref implementation on Server.
918     ///
919     /// # Errors
920     /// Returns an error if the underlying NSS functions fail.
enable_0rtt( &mut self, anti_replay: &AntiReplay, max_early_data: u32, checker: Box<dyn ZeroRttChecker>, ) -> Res<()>921     pub fn enable_0rtt(
922         &mut self,
923         anti_replay: &AntiReplay,
924         max_early_data: u32,
925         checker: Box<dyn ZeroRttChecker>,
926     ) -> Res<()> {
927         let mut check_state = Box::pin(ZeroRttCheckState::new(self.agent.fd, checker));
928         unsafe {
929             ssl::SSL_HelloRetryRequestCallback(
930                 self.agent.fd,
931                 Some(Self::hello_retry_cb),
932                 as_c_void(&mut check_state),
933             )
934         }?;
935         unsafe { ssl::SSL_SetMaxEarlyDataSize(self.agent.fd, max_early_data) }?;
936         self.zero_rtt_check = Some(check_state);
937         self.agent.enable_0rtt()?;
938         anti_replay.config_socket(self.fd)?;
939         Ok(())
940     }
941 
942     /// Send a session ticket to the client.
943     /// This adds |extra| application-specific content into that ticket.
944     /// The records that are sent are captured and returned.
945     ///
946     /// # Errors
947     /// If NSS is unable to send a ticket, or if this agent is incorrectly configured.
send_ticket(&mut self, now: Instant, extra: &[u8]) -> Res<RecordList>948     pub fn send_ticket(&mut self, now: Instant, extra: &[u8]) -> Res<RecordList> {
949         self.agent.now.set(now)?;
950         let records = self.setup_raw()?;
951 
952         unsafe {
953             ssl::SSL_SendSessionTicket(self.fd, extra.as_ptr(), c_uint::try_from(extra.len())?)
954         }?;
955 
956         Ok(*Pin::into_inner(records))
957     }
958 }
959 
960 impl Deref for Server {
961     type Target = SecretAgent;
962     #[must_use]
deref(&self) -> &SecretAgent963     fn deref(&self) -> &SecretAgent {
964         &self.agent
965     }
966 }
967 
968 impl DerefMut for Server {
deref_mut(&mut self) -> &mut SecretAgent969     fn deref_mut(&mut self) -> &mut SecretAgent {
970         &mut self.agent
971     }
972 }
973 
974 /// A generic container for Client or Server.
975 #[derive(Debug)]
976 pub enum Agent {
977     Client(crate::agent::Client),
978     Server(crate::agent::Server),
979 }
980 
981 impl Deref for Agent {
982     type Target = SecretAgent;
983     #[must_use]
deref(&self) -> &SecretAgent984     fn deref(&self) -> &SecretAgent {
985         match self {
986             Self::Client(c) => &*c,
987             Self::Server(s) => &*s,
988         }
989     }
990 }
991 
992 impl DerefMut for Agent {
deref_mut(&mut self) -> &mut SecretAgent993     fn deref_mut(&mut self) -> &mut SecretAgent {
994         match self {
995             Self::Client(c) => c.deref_mut(),
996             Self::Server(s) => s.deref_mut(),
997         }
998     }
999 }
1000 
1001 impl From<Client> for Agent {
1002     #[must_use]
from(c: Client) -> Self1003     fn from(c: Client) -> Self {
1004         Self::Client(c)
1005     }
1006 }
1007 
1008 impl From<Server> for Agent {
1009     #[must_use]
from(s: Server) -> Self1010     fn from(s: Server) -> Self {
1011         Self::Server(s)
1012     }
1013 }
1014