1 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4 // option. This file may not be copied, modified, or distributed
5 // except according to those terms.
6
7 pub use crate::agentio::{as_c_void, Record, RecordList};
8 use crate::agentio::{AgentIo, METHODS};
9 use crate::assert_initialized;
10 use crate::auth::AuthenticationStatus;
11 pub use crate::cert::CertificateInfo;
12 use crate::constants::*;
13 use crate::err::{is_blocked, secstatus_to_res, Error, PRErrorCode, Res};
14 use crate::ext::{ExtensionHandler, ExtensionTracker};
15 use crate::p11;
16 use crate::prio;
17 use crate::replay::AntiReplay;
18 use crate::secrets::SecretHolder;
19 use crate::ssl::{self, PRBool};
20 use crate::time::TimeHolder;
21
22 use neqo_common::{matches, qdebug, qinfo, qtrace, qwarn};
23 use std::cell::RefCell;
24 use std::convert::TryFrom;
25 use std::ffi::CString;
26 use std::mem::{self, MaybeUninit};
27 use std::ops::{Deref, DerefMut};
28 use std::os::raw::{c_uint, c_void};
29 use std::pin::Pin;
30 use std::ptr::{null, null_mut, NonNull};
31 use std::rc::Rc;
32 use std::time::Instant;
33
/// Progress of a TLS handshake driven by a `SecretAgent`.
#[derive(Clone, Debug, PartialEq)]
pub enum HandshakeState {
    /// No handshake step has run yet; a freshly created agent starts here.
    New,
    /// The handshake is blocked (for example, waiting for more data) and no
    /// certificate authentication is pending.
    InProgress,
    /// The handshake is blocked until the application authenticates the peer
    /// certificate by calling `SecretAgent::authenticated`.
    AuthenticationPending,
    /// The application supplied an authentication result; the contained code
    /// is passed to `SSL_AuthCertificateComplete` on the next handshake call.
    Authenticated(PRErrorCode),
    /// The handshake finished; holds information about the negotiated connection.
    Complete(SecretAgentInfo),
    /// The handshake failed with the contained error.
    Failed(Error),
}
43
44 impl HandshakeState {
45 #[must_use]
is_connected(&self) -> bool46 pub fn is_connected(&self) -> bool {
47 matches!(self, Self::Complete(_))
48 }
49
50 #[must_use]
is_final(&self) -> bool51 pub fn is_final(&self) -> bool {
52 matches!(self, Self::Complete(_) | Self::Failed(_))
53 }
54 }
55
get_alpn(fd: *mut ssl::PRFileDesc, pre: bool) -> Res<Option<String>>56 fn get_alpn(fd: *mut ssl::PRFileDesc, pre: bool) -> Res<Option<String>> {
57 let mut alpn_state = ssl::SSLNextProtoState::SSL_NEXT_PROTO_NO_SUPPORT;
58 let mut chosen = vec![0_u8; 255];
59 let mut chosen_len: c_uint = 0;
60 secstatus_to_res(unsafe {
61 ssl::SSL_GetNextProto(
62 fd,
63 &mut alpn_state,
64 chosen.as_mut_ptr(),
65 &mut chosen_len,
66 c_uint::try_from(chosen.len())?,
67 )
68 })?;
69
70 let alpn = match (pre, alpn_state) {
71 (true, ssl::SSLNextProtoState::SSL_NEXT_PROTO_EARLY_VALUE)
72 | (false, ssl::SSLNextProtoState::SSL_NEXT_PROTO_NEGOTIATED)
73 | (false, ssl::SSLNextProtoState::SSL_NEXT_PROTO_SELECTED) => {
74 chosen.truncate(chosen_len as usize);
75 Some(match String::from_utf8(chosen) {
76 Ok(a) => a,
77 _ => return Err(Error::InternalError),
78 })
79 }
80 _ => None,
81 };
82 qtrace!([format!("{:p}", fd)], "got ALPN {:?}", alpn);
83 Ok(alpn)
84 }
85
/// Preliminary information about a connection, available before the handshake
/// completes; produced by `SecretAgent::preinfo`.
pub struct SecretAgentPreInfo {
    /// Raw preliminary channel information from NSS.  Accessors generated by
    /// `preinfo_arg!` only report a field when its bit is set in `valuesSet`.
    info: ssl::SSLPreliminaryChannelInfo,
    /// The early ALPN value (NSS state `SSL_NEXT_PROTO_EARLY_VALUE`), if any.
    alpn: Option<String>,
}
90
// Generates an accessor `$v` that returns `Some(self.info.$f as $t)` when the
// `ssl::$m` bit is present in `self.info.valuesSet`, and `None` otherwise.
macro_rules! preinfo_arg {
    ($v:ident, $m:ident, $f:ident: $t:ident $(,)?) => {
        #[must_use]
        pub fn $v(&self) -> Option<$t> {
            if (self.info.valuesSet & ssl::$m) == 0 {
                None
            } else {
                Some(self.info.$f as $t)
            }
        }
    };
}
102
103 impl SecretAgentPreInfo {
new(fd: *mut ssl::PRFileDesc) -> Res<Self>104 fn new(fd: *mut ssl::PRFileDesc) -> Res<Self> {
105 let mut info: MaybeUninit<ssl::SSLPreliminaryChannelInfo> = MaybeUninit::uninit();
106 secstatus_to_res(unsafe {
107 ssl::SSL_GetPreliminaryChannelInfo(
108 fd,
109 info.as_mut_ptr(),
110 c_uint::try_from(mem::size_of::<ssl::SSLPreliminaryChannelInfo>())?,
111 )
112 })?;
113
114 Ok(Self {
115 info: unsafe { info.assume_init() },
116 alpn: get_alpn(fd, true)?,
117 })
118 }
119
120 preinfo_arg!(version, ssl_preinfo_version, protocolVersion: Version);
121 preinfo_arg!(cipher_suite, ssl_preinfo_cipher_suite, cipherSuite: Cipher);
122 #[must_use]
early_data(&self) -> bool123 pub fn early_data(&self) -> bool {
124 self.info.canSendEarlyData != 0
125 }
126 #[must_use]
max_early_data(&self) -> usize127 pub fn max_early_data(&self) -> usize {
128 usize::try_from(self.info.maxEarlyDataSize).unwrap()
129 }
130 #[must_use]
alpn(&self) -> Option<&String>131 pub fn alpn(&self) -> Option<&String> {
132 self.alpn.as_ref()
133 }
134
135 preinfo_arg!(
136 early_data_cipher,
137 ssl_preinfo_0rtt_cipher_suite,
138 zeroRttCipherSuite: Cipher,
139 );
140 }
141
/// Information about a connection whose handshake has completed; stored in
/// `HandshakeState::Complete`.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct SecretAgentInfo {
    /// Negotiated protocol version.
    version: Version,
    /// Negotiated cipher suite.
    cipher: Cipher,
    /// Key exchange group.
    group: Group,
    /// Whether NSS reports the session as resumed.
    resumed: bool,
    /// Whether NSS reports that early data was accepted.
    early_data: bool,
    /// Negotiated or selected ALPN value, if any.
    alpn: Option<String>,
    /// Signature scheme reported by NSS.
    signature_scheme: SignatureScheme,
}
152
153 impl SecretAgentInfo {
new(fd: *mut ssl::PRFileDesc) -> Res<Self>154 fn new(fd: *mut ssl::PRFileDesc) -> Res<Self> {
155 let mut info: MaybeUninit<ssl::SSLChannelInfo> = MaybeUninit::uninit();
156 secstatus_to_res(unsafe {
157 ssl::SSL_GetChannelInfo(
158 fd,
159 info.as_mut_ptr(),
160 c_uint::try_from(mem::size_of::<ssl::SSLChannelInfo>())?,
161 )
162 })?;
163 let info = unsafe { info.assume_init() };
164 Ok(Self {
165 version: info.protocolVersion as Version,
166 cipher: info.cipherSuite as Cipher,
167 group: Group::try_from(info.keaGroup)?,
168 resumed: info.resumed != 0,
169 early_data: info.earlyDataAccepted != 0,
170 alpn: get_alpn(fd, false)?,
171 signature_scheme: SignatureScheme::try_from(info.signatureScheme)?,
172 })
173 }
174 #[must_use]
version(&self) -> Version175 pub fn version(&self) -> Version {
176 self.version
177 }
178 #[must_use]
cipher_suite(&self) -> Cipher179 pub fn cipher_suite(&self) -> Cipher {
180 self.cipher
181 }
182 #[must_use]
key_exchange(&self) -> Group183 pub fn key_exchange(&self) -> Group {
184 self.group
185 }
186 #[must_use]
resumed(&self) -> bool187 pub fn resumed(&self) -> bool {
188 self.resumed
189 }
190 #[must_use]
early_data_accepted(&self) -> bool191 pub fn early_data_accepted(&self) -> bool {
192 self.early_data
193 }
194 #[must_use]
alpn(&self) -> Option<&String>195 pub fn alpn(&self) -> Option<&String> {
196 self.alpn.as_ref()
197 }
198 #[must_use]
signature_scheme(&self) -> SignatureScheme199 pub fn signature_scheme(&self) -> SignatureScheme {
200 self.signature_scheme
201 }
202 }
203
/// `SecretAgent` holds the common parts of client and server.
#[derive(Debug)]
#[allow(clippy::module_name_repetitions)]
pub struct SecretAgent {
    /// The NSS SSL socket; set to null once `close` has run.
    fd: *mut ssl::PRFileDesc,
    /// Traffic secrets captured from NSS, taken via `read_secret`/`write_secret`.
    secrets: SecretHolder,
    /// `None` until the first handshake call; then records whether the raw
    /// (record-based) or byte-based handshake API is in use.
    raw: Option<bool>,
    /// Pinned I/O adapter; the socket's I/O layer holds a pointer into it.
    io: Pin<Box<AgentIo>>,
    /// Current handshake state.
    state: HandshakeState,

    /// Records whether authentication of certificates is required.
    auth_required: Pin<Box<bool>>,
    /// Records any fatal alert that is sent by the stack.
    alert: Pin<Box<Option<Alert>>>,
    /// The current time.
    now: TimeHolder,

    /// Installed extension handlers, kept alive for the socket's lifetime.
    extension_handlers: Vec<ExtensionTracker>,
    /// NOTE(review): not read anywhere in this file; `info()` reads from
    /// `state` instead.  Confirm whether this field is still needed.
    inf: Option<SecretAgentInfo>,

    /// Whether or not EndOfEarlyData should be suppressed.
    no_eoed: bool,
}
227
228 impl SecretAgent {
new() -> Res<Self>229 fn new() -> Res<Self> {
230 let mut io = Box::pin(AgentIo::new());
231 let fd = Self::create_fd(&mut io)?;
232 Ok(Self {
233 fd,
234 secrets: SecretHolder::default(),
235 raw: None,
236 io,
237 state: HandshakeState::New,
238
239 auth_required: Box::pin(false),
240 alert: Box::pin(None),
241 now: TimeHolder::default(),
242
243 extension_handlers: Vec::new(),
244 inf: None,
245
246 no_eoed: false,
247 })
248 }
249
250 // Create a new SSL file descriptor.
251 //
252 // Note that we create separate bindings for PRFileDesc as both
253 // ssl::PRFileDesc and prio::PRFileDesc. This keeps the bindings
254 // minimal, but it means that the two forms need casts to translate
255 // between them. ssl::PRFileDesc is left as an opaque type, as the
256 // ssl::SSL_* APIs only need an opaque type.
create_fd(io: &mut Pin<Box<AgentIo>>) -> Res<*mut ssl::PRFileDesc>257 fn create_fd(io: &mut Pin<Box<AgentIo>>) -> Res<*mut ssl::PRFileDesc> {
258 assert_initialized();
259 let label = CString::new("sslwrapper")?;
260 let id = unsafe { prio::PR_GetUniqueIdentity(label.as_ptr()) };
261
262 let base_fd = unsafe { prio::PR_CreateIOLayerStub(id, METHODS) };
263 if base_fd.is_null() {
264 return Err(Error::CreateSslSocket);
265 }
266 let fd = unsafe {
267 (*base_fd).secret = as_c_void(io) as *mut _;
268 ssl::SSL_ImportFD(null_mut(), base_fd as *mut ssl::PRFileDesc)
269 };
270 if fd.is_null() {
271 unsafe { prio::PR_Close(base_fd) };
272 return Err(Error::CreateSslSocket);
273 }
274 Ok(fd)
275 }
276
auth_complete_hook( arg: *mut c_void, _fd: *mut ssl::PRFileDesc, _check_sig: ssl::PRBool, _is_server: ssl::PRBool, ) -> ssl::SECStatus277 unsafe extern "C" fn auth_complete_hook(
278 arg: *mut c_void,
279 _fd: *mut ssl::PRFileDesc,
280 _check_sig: ssl::PRBool,
281 _is_server: ssl::PRBool,
282 ) -> ssl::SECStatus {
283 let auth_required_ptr = arg as *mut bool;
284 *auth_required_ptr = true;
285 // NSS insists on getting SECWouldBlock here rather than accepting
286 // the usual combination of PR_WOULD_BLOCK_ERROR and SECFailure.
287 ssl::_SECStatus_SECWouldBlock
288 }
289
alert_sent_cb( fd: *const ssl::PRFileDesc, arg: *mut c_void, alert: *const ssl::SSLAlert, )290 unsafe extern "C" fn alert_sent_cb(
291 fd: *const ssl::PRFileDesc,
292 arg: *mut c_void,
293 alert: *const ssl::SSLAlert,
294 ) {
295 let alert = alert.as_ref().unwrap();
296 if alert.level == 2 {
297 // Fatal alerts demand attention.
298 let p = arg as *mut Option<Alert>;
299 let st = p.as_mut().unwrap();
300 if st.is_none() {
301 *st = Some(alert.description);
302 } else {
303 qwarn!(
304 [format!("{:p}", fd)],
305 "duplicate alert {}",
306 alert.description
307 );
308 }
309 }
310 }
311
312 // Ready this for connecting.
ready(&mut self, is_server: bool) -> Res<()>313 fn ready(&mut self, is_server: bool) -> Res<()> {
314 secstatus_to_res(unsafe {
315 ssl::SSL_AuthCertificateHook(
316 self.fd,
317 Some(Self::auth_complete_hook),
318 as_c_void(&mut self.auth_required),
319 )
320 })?;
321
322 secstatus_to_res(unsafe {
323 ssl::SSL_AlertSentCallback(
324 self.fd,
325 Some(Self::alert_sent_cb),
326 as_c_void(&mut self.alert),
327 )
328 })?;
329
330 self.now.bind(self.fd)?;
331 self.configure()?;
332 secstatus_to_res(unsafe { ssl::SSL_ResetHandshake(self.fd, is_server as ssl::PRBool) })
333 }
334
335 /// Default configuration.
336 ///
337 /// # Errors
338 /// If `set_version_range` fails.
configure(&mut self) -> Res<()>339 fn configure(&mut self) -> Res<()> {
340 self.set_version_range(TLS_VERSION_1_3, TLS_VERSION_1_3)?;
341 self.set_option(ssl::Opt::Locking, false)?;
342 self.set_option(ssl::Opt::Tickets, false)?;
343 self.set_option(ssl::Opt::OcspStapling, true)?;
344 Ok(())
345 }
346
347 /// Set the versions that are supported.
348 ///
349 /// # Errors
350 /// If the range of versions isn't supported.
set_version_range(&mut self, min: Version, max: Version) -> Res<()>351 pub fn set_version_range(&mut self, min: Version, max: Version) -> Res<()> {
352 let range = ssl::SSLVersionRange {
353 min: min as ssl::PRUint16,
354 max: max as ssl::PRUint16,
355 };
356 secstatus_to_res(unsafe { ssl::SSL_VersionRangeSet(self.fd, &range) })
357 }
358
359 /// Enable a set of ciphers. Note that the order of these is not respected.
360 ///
361 /// # Errors
362 /// If NSS can't enable or disable ciphers.
enable_ciphers(&mut self, ciphers: &[Cipher]) -> Res<()>363 pub fn enable_ciphers(&mut self, ciphers: &[Cipher]) -> Res<()> {
364 let all_ciphers = unsafe { ssl::SSL_GetImplementedCiphers() };
365 let cipher_count = unsafe { ssl::SSL_GetNumImplementedCiphers() } as usize;
366 for i in 0..cipher_count {
367 let p = all_ciphers.wrapping_add(i);
368 secstatus_to_res(unsafe {
369 ssl::SSL_CipherPrefSet(self.fd, i32::from(*p), false as ssl::PRBool)
370 })?;
371 }
372
373 for c in ciphers {
374 secstatus_to_res(unsafe {
375 ssl::SSL_CipherPrefSet(self.fd, i32::from(*c), true as ssl::PRBool)
376 })?;
377 }
378 Ok(())
379 }
380
381 /// Set key exchange groups.
382 ///
383 /// # Errors
384 /// If the underlying API fails (which shouldn't happen).
set_groups(&mut self, groups: &[Group]) -> Res<()>385 pub fn set_groups(&mut self, groups: &[Group]) -> Res<()> {
386 // SSLNamedGroup is a different size to Group, so copy one by one.
387 let group_vec: Vec<_> = groups
388 .iter()
389 .map(|&g| ssl::SSLNamedGroup::Type::from(g))
390 .collect();
391
392 let ptr = group_vec.as_slice().as_ptr();
393 secstatus_to_res(unsafe {
394 ssl::SSL_NamedGroupConfig(self.fd, ptr, c_uint::try_from(group_vec.len())?)
395 })
396 }
397
398 /// Set TLS options.
399 ///
400 /// # Errors
401 /// Returns an error if the option or option value is invalid; i.e., never.
set_option(&mut self, opt: ssl::Opt, value: bool) -> Res<()>402 pub fn set_option(&mut self, opt: ssl::Opt, value: bool) -> Res<()> {
403 secstatus_to_res(unsafe {
404 ssl::SSL_OptionSet(self.fd, opt.as_int(), opt.map_enabled(value))
405 })
406 }
407
408 /// Enable 0-RTT.
409 ///
410 /// # Errors
411 /// See `set_option`.
enable_0rtt(&mut self) -> Res<()>412 pub fn enable_0rtt(&mut self) -> Res<()> {
413 self.set_option(ssl::Opt::EarlyData, true)
414 }
415
416 /// Disable the `EndOfEarlyData` message.
disable_end_of_early_data(&mut self)417 pub fn disable_end_of_early_data(&mut self) {
418 self.no_eoed = true;
419 }
420
421 /// `set_alpn` sets a list of preferred protocols, starting with the most preferred.
422 /// Though ALPN [RFC7301] permits octet sequences, this only allows for UTF-8-encoded
423 /// strings.
424 ///
425 /// This asserts if no items are provided, or if any individual item is longer than
426 /// 255 octets in length.
427 ///
428 /// # Errors
429 /// This should always panic rather than return an error.
set_alpn(&mut self, protocols: &[impl AsRef<str>]) -> Res<()>430 pub fn set_alpn(&mut self, protocols: &[impl AsRef<str>]) -> Res<()> {
431 // Validate and set length.
432 let mut encoded_len = protocols.len();
433 for v in protocols {
434 assert!(v.as_ref().len() < 256);
435 encoded_len += v.as_ref().len();
436 }
437
438 // Prepare to encode.
439 let mut encoded = Vec::with_capacity(encoded_len);
440 let mut add = |v: &str| {
441 if let Ok(s) = u8::try_from(v.len()) {
442 encoded.push(s);
443 encoded.extend_from_slice(v.as_bytes());
444 }
445 };
446
447 // NSS inherited an idiosyncratic API as a result of having implemented NPN
448 // before ALPN. For that reason, we need to put the "best" option last.
449 let (first, rest) = protocols
450 .split_first()
451 .expect("at least one ALPN value needed");
452 for v in rest {
453 add(v.as_ref());
454 }
455 add(first.as_ref());
456 assert_eq!(encoded_len, encoded.len());
457
458 // Now give the result to NSS.
459 secstatus_to_res(unsafe {
460 ssl::SSL_SetNextProtoNego(
461 self.fd,
462 encoded.as_slice().as_ptr(),
463 c_uint::try_from(encoded.len())?,
464 )
465 })
466 }
467
468 /// Install an extension handler.
469 ///
470 /// This can be called multiple times with different values for `ext`. The handler is provided as
471 /// Rc<RefCell<>> so that the caller is able to hold a reference to the handler and later access any
472 /// state that it accumulates.
473 ///
474 /// # Errors
475 /// When the extension handler can't be successfully installed.
extension_handler( &mut self, ext: Extension, handler: Rc<RefCell<dyn ExtensionHandler>>, ) -> Res<()>476 pub fn extension_handler(
477 &mut self,
478 ext: Extension,
479 handler: Rc<RefCell<dyn ExtensionHandler>>,
480 ) -> Res<()> {
481 let tracker = unsafe { ExtensionTracker::new(self.fd, ext, handler) }?;
482 self.extension_handlers.push(tracker);
483 Ok(())
484 }
485
486 // This function tracks whether handshake() or handshake_raw() was used
487 // and prevents the other from being used.
set_raw(&mut self, r: bool) -> Res<()>488 fn set_raw(&mut self, r: bool) -> Res<()> {
489 if self.raw.is_none() {
490 self.secrets.register(self.fd)?;
491 self.raw = Some(r);
492 Ok(())
493 } else if self.raw.unwrap() == r {
494 Ok(())
495 } else {
496 Err(Error::MixedHandshakeMethod)
497 }
498 }
499
500 /// Get information about the connection.
501 /// This includes the version, ciphersuite, and ALPN.
502 ///
503 /// Calling this function returns None until the connection is complete.
504 #[must_use]
info(&self) -> Option<&SecretAgentInfo>505 pub fn info(&self) -> Option<&SecretAgentInfo> {
506 match self.state {
507 HandshakeState::Complete(ref info) => Some(info),
508 _ => None,
509 }
510 }
511
512 /// Get any preliminary information about the status of the connection.
513 ///
514 /// This includes whether 0-RTT was accepted and any information related to that.
515 /// Calling this function collects all the relevant information.
516 ///
517 /// # Errors
518 /// When the underlying socket functions fail.
preinfo(&self) -> Res<SecretAgentPreInfo>519 pub fn preinfo(&self) -> Res<SecretAgentPreInfo> {
520 SecretAgentPreInfo::new(self.fd)
521 }
522
523 /// Get the peer's certificate chain.
524 #[must_use]
peer_certificate(&self) -> Option<CertificateInfo>525 pub fn peer_certificate(&self) -> Option<CertificateInfo> {
526 CertificateInfo::new(self.fd)
527 }
528
529 /// Return any fatal alert that the TLS stack might have sent.
530 #[must_use]
alert(&self) -> Option<&Alert>531 pub fn alert(&self) -> Option<&Alert> {
532 (&*self.alert).as_ref()
533 }
534
535 /// Call this function to mark the peer as authenticated.
536 /// Only call this function if `handshake/handshake_raw` returns
537 /// `HandshakeState::AuthenticationPending`, or it will panic.
authenticated(&mut self, status: AuthenticationStatus)538 pub fn authenticated(&mut self, status: AuthenticationStatus) {
539 assert_eq!(self.state, HandshakeState::AuthenticationPending);
540 *self.auth_required = false;
541 self.state = HandshakeState::Authenticated(status.into());
542 }
543
capture_error<T>(&mut self, res: Res<T>) -> Res<T>544 fn capture_error<T>(&mut self, res: Res<T>) -> Res<T> {
545 if let Err(e) = &res {
546 qwarn!([self], "error: {:?}", e);
547 self.state = HandshakeState::Failed(e.clone());
548 }
549 res
550 }
551
update_state(&mut self, res: Res<()>) -> Res<()>552 fn update_state(&mut self, res: Res<()>) -> Res<()> {
553 self.state = if is_blocked(&res) {
554 if *self.auth_required {
555 HandshakeState::AuthenticationPending
556 } else {
557 HandshakeState::InProgress
558 }
559 } else {
560 self.capture_error(res)?;
561 let info = self.capture_error(SecretAgentInfo::new(self.fd))?;
562 HandshakeState::Complete(info)
563 };
564 qinfo!([self], "state -> {:?}", self.state);
565 Ok(())
566 }
567
568 /// Drive the TLS handshake, taking bytes from `input` and putting
569 /// any bytes necessary into `output`.
570 /// This takes the current time as `now`.
571 /// On success a tuple of a `HandshakeState` and usize indicate whether the handshake
572 /// is complete and how many bytes were written to `output`, respectively.
573 /// If the state is `HandshakeState::AuthenticationPending`, then ONLY call this
574 /// function if you want to proceed, because this will mark the certificate as OK.
575 ///
576 /// # Errors
577 /// When the handshake fails this returns an error.
handshake(&mut self, now: Instant, input: &[u8]) -> Res<Vec<u8>>578 pub fn handshake(&mut self, now: Instant, input: &[u8]) -> Res<Vec<u8>> {
579 self.now.set(now)?;
580 self.set_raw(false)?;
581
582 let rv = {
583 // Within this scope, _h maintains a mutable reference to self.io.
584 let _h = self.io.wrap(input);
585 match self.state {
586 HandshakeState::Authenticated(ref err) => unsafe {
587 ssl::SSL_AuthCertificateComplete(self.fd, *err)
588 },
589 _ => unsafe { ssl::SSL_ForceHandshake(self.fd) },
590 }
591 };
592 // Take before updating state so that we leave the output buffer empty
593 // even if there is an error.
594 let output = self.io.take_output();
595 self.update_state(secstatus_to_res(rv))?;
596 Ok(output)
597 }
598
599 /// Setup to receive records for raw handshake functions.
setup_raw(&mut self) -> Res<Pin<Box<RecordList>>>600 fn setup_raw(&mut self) -> Res<Pin<Box<RecordList>>> {
601 self.set_raw(true)?;
602 self.capture_error(RecordList::setup(self.fd))
603 }
604
inject_eoed(&mut self) -> Res<()>605 fn inject_eoed(&mut self) -> Res<()> {
606 // EndOfEarlyData is as follows:
607 // struct {
608 // HandshakeType msg_type = end_of_early_data(5);
609 // uint24 length = 0;
610 // };
611 const END_OF_EARLY_DATA: &[u8] = &[5, 0, 0, 0];
612
613 if self.no_eoed {
614 let mut read_epoch: u16 = 0;
615 unsafe { ssl::SSL_GetCurrentEpoch(self.fd, &mut read_epoch, null_mut()) }?;
616 if read_epoch == 1 {
617 // It's waiting for EndOfEarlyData, so feed one in.
618 // Note that this is the test that ensures that we only do this for the server.
619 let eoed = Record::new(1, 22, END_OF_EARLY_DATA);
620 self.capture_error(eoed.write(self.fd))?;
621 self.no_eoed = false;
622 }
623 }
624 Ok(())
625 }
626
627 /// Drive the TLS handshake, but get the raw content of records, not
628 /// protected records as bytes. This function is incompatible with
629 /// `handshake()`; use either this or `handshake()` exclusively.
630 ///
631 /// Ideally, this only includes records from the current epoch.
632 /// If you send data from multiple epochs, you might end up being sad.
633 ///
634 /// # Errors
635 /// When the handshake fails this returns an error.
handshake_raw(&mut self, now: Instant, input: Option<Record>) -> Res<RecordList>636 pub fn handshake_raw(&mut self, now: Instant, input: Option<Record>) -> Res<RecordList> {
637 self.now.set(now)?;
638 let mut records = self.setup_raw()?;
639
640 // Fire off any authentication we might need to complete.
641 if let HandshakeState::Authenticated(ref err) = self.state {
642 let result =
643 secstatus_to_res(unsafe { ssl::SSL_AuthCertificateComplete(self.fd, *err) });
644 qdebug!([self], "SSL_AuthCertificateComplete: {:?}", result);
645 // This should return SECSuccess, so don't use update_state().
646 self.capture_error(result)?;
647 }
648
649 // Feed in any records.
650 if let Some(rec) = input {
651 if rec.epoch == 2 {
652 self.inject_eoed()?;
653 }
654 self.capture_error(rec.write(self.fd))?;
655 }
656
657 // Drive the handshake once more.
658 let rv = secstatus_to_res(unsafe { ssl::SSL_ForceHandshake(self.fd) });
659 self.update_state(rv)?;
660
661 if self.no_eoed {
662 records.remove_eoed();
663 }
664
665 Ok(*Pin::into_inner(records))
666 }
667
close(&mut self)668 pub fn close(&mut self) {
669 // It should be safe to close multiple times.
670 if self.fd.is_null() {
671 return;
672 }
673 if let Some(true) = self.raw {
674 // Need to hold the record list in scope until the close is done.
675 let _records = self.setup_raw().expect("Can only close");
676 unsafe { prio::PR_Close(self.fd as *mut prio::PRFileDesc) };
677 } else {
678 // Need to hold the IO wrapper in scope until the close is done.
679 let _io = self.io.wrap(&[]);
680 unsafe { prio::PR_Close(self.fd as *mut prio::PRFileDesc) };
681 };
682 let _output = self.io.take_output();
683 self.fd = null_mut();
684 }
685
686 /// State returns the status of the handshake.
687 #[must_use]
state(&self) -> &HandshakeState688 pub fn state(&self) -> &HandshakeState {
689 &self.state
690 }
691 /// Take a read secret. This will only return a non-`None` value once.
692 #[must_use]
read_secret(&mut self, epoch: Epoch) -> Option<p11::SymKey>693 pub fn read_secret(&mut self, epoch: Epoch) -> Option<p11::SymKey> {
694 self.secrets.take_read(epoch)
695 }
696 /// Take a write secret.
697 #[must_use]
write_secret(&mut self, epoch: Epoch) -> Option<p11::SymKey>698 pub fn write_secret(&mut self, epoch: Epoch) -> Option<p11::SymKey> {
699 self.secrets.take_write(epoch)
700 }
701 }
702
/// Dropping an agent closes its NSS socket (a no-op if `close` already ran).
impl Drop for SecretAgent {
    fn drop(&mut self) {
        self.close();
    }
}
708
/// Formats as `Agent <socket pointer>`, used as the log prefix by this module.
impl ::std::fmt::Display for SecretAgent {
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        write!(f, "Agent {:p}", self.fd)
    }
}
714
/// A TLS Client.
///
/// Dereferences to the shared `SecretAgent`.
#[derive(Debug)]
pub struct Client {
    /// The shared agent state.
    agent: SecretAgent,

    /// Records the last resumption token.
    /// Pinned because NSS is given a raw pointer to it in `Client::ready`.
    resumption: Pin<Box<Option<Vec<u8>>>>,
}
723
724 impl Client {
725 /// Create a new client agent.
726 ///
727 /// # Errors
728 /// Errors returned if the socket can't be created or configured.
new(server_name: &str) -> Res<Self>729 pub fn new(server_name: &str) -> Res<Self> {
730 let mut agent = SecretAgent::new()?;
731 let url = CString::new(server_name)?;
732 secstatus_to_res(unsafe { ssl::SSL_SetURL(agent.fd, url.as_ptr()) })?;
733 agent.ready(false)?;
734 let mut client = Self {
735 agent,
736 resumption: Box::pin(None),
737 };
738 client.ready()?;
739 Ok(client)
740 }
741
resumption_token_cb( fd: *mut ssl::PRFileDesc, token: *const u8, len: c_uint, arg: *mut c_void, ) -> ssl::SECStatus742 unsafe extern "C" fn resumption_token_cb(
743 fd: *mut ssl::PRFileDesc,
744 token: *const u8,
745 len: c_uint,
746 arg: *mut c_void,
747 ) -> ssl::SECStatus {
748 let resumption_ptr = arg as *mut Option<Vec<u8>>;
749 let resumption = resumption_ptr.as_mut().unwrap();
750 let mut v = Vec::with_capacity(len as usize);
751 v.extend_from_slice(std::slice::from_raw_parts(token, len as usize));
752 qdebug!([format!("{:p}", fd)], "Got resumption token");
753 *resumption = Some(v);
754 ssl::SECSuccess
755 }
756
ready(&mut self) -> Res<()>757 fn ready(&mut self) -> Res<()> {
758 let fd = self.fd;
759 unsafe {
760 ssl::SSL_SetResumptionTokenCallback(
761 fd,
762 Some(Self::resumption_token_cb),
763 as_c_void(&mut self.resumption),
764 )
765 }
766 }
767
768 /// Return the resumption token.
769 #[must_use]
resumption_token(&self) -> Option<&Vec<u8>>770 pub fn resumption_token(&self) -> Option<&Vec<u8>> {
771 (*self.resumption).as_ref()
772 }
773
774 /// Enable resumption, using a token previously provided.
775 ///
776 /// # Errors
777 /// Error returned when the resumption token is invalid or
778 /// the socket is not able to use the value.
set_resumption_token(&mut self, token: &[u8]) -> Res<()>779 pub fn set_resumption_token(&mut self, token: &[u8]) -> Res<()> {
780 unsafe {
781 ssl::SSL_SetResumptionToken(
782 self.agent.fd,
783 token.as_ptr(),
784 c_uint::try_from(token.len())?,
785 )
786 }
787 }
788 }
789
/// Exposes the shared `SecretAgent` API on a `Client`.
impl Deref for Client {
    type Target = SecretAgent;
    #[must_use]
    fn deref(&self) -> &SecretAgent {
        &self.agent
    }
}
797
/// Mutable access to the shared `SecretAgent` of a `Client`.
impl DerefMut for Client {
    fn deref_mut(&mut self) -> &mut SecretAgent {
        &mut self.agent
    }
}
803
/// `ZeroRttCheckResult` encapsulates the options for handling a `ClientHello`.
#[derive(Clone, Debug, PartialEq)]
pub enum ZeroRttCheckResult {
    /// Accept 0-RTT; the default.
    Accept,
    /// Reject 0-RTT, but continue the handshake normally.
    Reject,
    /// Send HelloRetryRequest (probably not needed for QUIC).
    /// The bytes are copied into the retry token that NSS sends; they must fit
    /// within the maximum size NSS allows.
    HelloRetryRequest(Vec<u8>),
    /// Fail the handshake.
    Fail,
}
816
/// A `ZeroRttChecker` is used by the agent to validate the application token (as provided by
/// `send_ticket`) that accompanies a client's first `ClientHello`.
pub trait ZeroRttChecker: std::fmt::Debug + std::marker::Unpin {
    /// Decide how to handle a `ClientHello` carrying `token`.
    /// `token` is empty if NSS supplied no token.  This is not consulted for a
    /// second `ClientHello` sent after a HelloRetryRequest.
    fn check(&self, token: &[u8]) -> ZeroRttCheckResult;
}
821
/// Context passed to NSS as the argument of the HelloRetryRequest callback
/// installed by `Server::enable_0rtt`.
#[derive(Debug)]
struct ZeroRttCheckState {
    /// The socket the callback was installed on.
    /// NOTE(review): not read anywhere in this file.
    fd: *mut ssl::PRFileDesc,
    /// The application-supplied checker consulted by the callback.
    checker: Pin<Box<dyn ZeroRttChecker>>,
}
827
828 impl ZeroRttCheckState {
new(fd: *mut ssl::PRFileDesc, checker: Box<dyn ZeroRttChecker>) -> Self829 pub fn new(fd: *mut ssl::PRFileDesc, checker: Box<dyn ZeroRttChecker>) -> Self {
830 Self {
831 fd,
832 checker: Pin::new(checker),
833 }
834 }
835 }
836
/// A TLS Server.
///
/// Dereferences to the shared `SecretAgent`.
#[derive(Debug)]
pub struct Server {
    /// The shared agent state.
    agent: SecretAgent,
    /// This holds the HRR callback context.
    /// NSS holds a raw pointer to it once `enable_0rtt` installs the callback.
    zero_rtt_check: Option<Pin<Box<ZeroRttCheckState>>>,
}
843
844 impl Server {
845 /// Create a new server agent.
846 ///
847 /// # Errors
848 /// Errors returned when NSS fails.
new(certificates: &[impl AsRef<str>]) -> Res<Self>849 pub fn new(certificates: &[impl AsRef<str>]) -> Res<Self> {
850 let mut agent = SecretAgent::new()?;
851
852 for n in certificates {
853 let c = CString::new(n.as_ref())?;
854 let cert = match NonNull::new(unsafe {
855 p11::PK11_FindCertFromNickname(c.as_ptr(), null_mut())
856 }) {
857 None => return Err(Error::CertificateLoading),
858 Some(ptr) => p11::Certificate::new(ptr),
859 };
860 let key = match NonNull::new(unsafe {
861 p11::PK11_FindKeyByAnyCert(*cert.deref(), null_mut())
862 }) {
863 None => return Err(Error::CertificateLoading),
864 Some(ptr) => p11::PrivateKey::new(ptr),
865 };
866 secstatus_to_res(unsafe {
867 ssl::SSL_ConfigServerCert(agent.fd, *cert.deref(), *key.deref(), null(), 0)
868 })?;
869 }
870
871 agent.ready(true)?;
872 Ok(Self {
873 agent,
874 zero_rtt_check: None,
875 })
876 }
877
hello_retry_cb( first_hello: PRBool, client_token: *const u8, client_token_len: c_uint, retry_token: *mut u8, retry_token_len: *mut c_uint, retry_token_max: c_uint, arg: *mut c_void, ) -> ssl::SSLHelloRetryRequestAction::Type878 unsafe extern "C" fn hello_retry_cb(
879 first_hello: PRBool,
880 client_token: *const u8,
881 client_token_len: c_uint,
882 retry_token: *mut u8,
883 retry_token_len: *mut c_uint,
884 retry_token_max: c_uint,
885 arg: *mut c_void,
886 ) -> ssl::SSLHelloRetryRequestAction::Type {
887 if first_hello == 0 {
888 // On the second ClientHello after HelloRetryRequest, skip checks.
889 return ssl::SSLHelloRetryRequestAction::ssl_hello_retry_accept;
890 }
891
892 let p = arg as *mut ZeroRttCheckState;
893 let check_state = p.as_mut().unwrap();
894 let token = if client_token.is_null() {
895 &[]
896 } else {
897 std::slice::from_raw_parts(client_token, client_token_len as usize)
898 };
899 match check_state.checker.check(token) {
900 ZeroRttCheckResult::Accept => ssl::SSLHelloRetryRequestAction::ssl_hello_retry_accept,
901 ZeroRttCheckResult::Fail => ssl::SSLHelloRetryRequestAction::ssl_hello_retry_fail,
902 ZeroRttCheckResult::Reject => {
903 ssl::SSLHelloRetryRequestAction::ssl_hello_retry_reject_0rtt
904 }
905 ZeroRttCheckResult::HelloRetryRequest(tok) => {
906 // Don't bother propagating errors from this, because it should be caught in testing.
907 assert!(tok.len() <= usize::try_from(retry_token_max).unwrap());
908 let slc = std::slice::from_raw_parts_mut(retry_token, tok.len());
909 slc.copy_from_slice(&tok);
910 *retry_token_len = c_uint::try_from(tok.len()).expect("token was way too big");
911 ssl::SSLHelloRetryRequestAction::ssl_hello_retry_request
912 }
913 }
914 }
915
916 /// Enable 0-RTT. This shadows the function of the same name that can be accessed
917 /// via the Deref implementation on Server.
918 ///
919 /// # Errors
920 /// Returns an error if the underlying NSS functions fail.
enable_0rtt( &mut self, anti_replay: &AntiReplay, max_early_data: u32, checker: Box<dyn ZeroRttChecker>, ) -> Res<()>921 pub fn enable_0rtt(
922 &mut self,
923 anti_replay: &AntiReplay,
924 max_early_data: u32,
925 checker: Box<dyn ZeroRttChecker>,
926 ) -> Res<()> {
927 let mut check_state = Box::pin(ZeroRttCheckState::new(self.agent.fd, checker));
928 unsafe {
929 ssl::SSL_HelloRetryRequestCallback(
930 self.agent.fd,
931 Some(Self::hello_retry_cb),
932 as_c_void(&mut check_state),
933 )
934 }?;
935 unsafe { ssl::SSL_SetMaxEarlyDataSize(self.agent.fd, max_early_data) }?;
936 self.zero_rtt_check = Some(check_state);
937 self.agent.enable_0rtt()?;
938 anti_replay.config_socket(self.fd)?;
939 Ok(())
940 }
941
942 /// Send a session ticket to the client.
943 /// This adds |extra| application-specific content into that ticket.
944 /// The records that are sent are captured and returned.
945 ///
946 /// # Errors
947 /// If NSS is unable to send a ticket, or if this agent is incorrectly configured.
send_ticket(&mut self, now: Instant, extra: &[u8]) -> Res<RecordList>948 pub fn send_ticket(&mut self, now: Instant, extra: &[u8]) -> Res<RecordList> {
949 self.agent.now.set(now)?;
950 let records = self.setup_raw()?;
951
952 unsafe {
953 ssl::SSL_SendSessionTicket(self.fd, extra.as_ptr(), c_uint::try_from(extra.len())?)
954 }?;
955
956 Ok(*Pin::into_inner(records))
957 }
958 }
959
/// Exposes the shared `SecretAgent` API on a `Server`.
impl Deref for Server {
    type Target = SecretAgent;
    #[must_use]
    fn deref(&self) -> &SecretAgent {
        &self.agent
    }
}
967
/// Mutable access to the shared `SecretAgent` of a `Server`.
impl DerefMut for Server {
    fn deref_mut(&mut self) -> &mut SecretAgent {
        &mut self.agent
    }
}
973
/// A generic container for Client or Server.
///
/// Dereferences to the shared `SecretAgent` of whichever role it holds.
#[derive(Debug)]
pub enum Agent {
    /// A client-side agent.
    Client(crate::agent::Client),
    /// A server-side agent.
    Server(crate::agent::Server),
}
980
981 impl Deref for Agent {
982 type Target = SecretAgent;
983 #[must_use]
deref(&self) -> &SecretAgent984 fn deref(&self) -> &SecretAgent {
985 match self {
986 Self::Client(c) => &*c,
987 Self::Server(s) => &*s,
988 }
989 }
990 }
991
992 impl DerefMut for Agent {
deref_mut(&mut self) -> &mut SecretAgent993 fn deref_mut(&mut self) -> &mut SecretAgent {
994 match self {
995 Self::Client(c) => c.deref_mut(),
996 Self::Server(s) => s.deref_mut(),
997 }
998 }
999 }
1000
1001 impl From<Client> for Agent {
1002 #[must_use]
from(c: Client) -> Self1003 fn from(c: Client) -> Self {
1004 Self::Client(c)
1005 }
1006 }
1007
1008 impl From<Server> for Agent {
1009 #[must_use]
from(s: Server) -> Self1010 fn from(s: Server) -> Self {
1011 Self::Server(s)
1012 }
1013 }
1014