1 #[cfg(feature = "risky-raw-split")] 2 use crate::constants::{CIPHERKEYLEN, MAXHASHLEN}; 3 #[cfg(feature = "hfs")] 4 use crate::constants::{MAXKEMCTLEN, MAXKEMPUBLEN, MAXKEMSSLEN}; 5 #[cfg(feature = "hfs")] 6 use crate::types::Kem; 7 use crate::{ 8 cipherstate::{CipherState, CipherStates}, 9 constants::{MAXDHLEN, MAXMSGLEN, PSKLEN, TAGLEN}, 10 error::{Error, InitStage, StateProblem}, 11 params::{DhToken, HandshakeTokens, MessagePatterns, NoiseParams, Token}, 12 stateless_transportstate::StatelessTransportState, 13 symmetricstate::SymmetricState, 14 transportstate::TransportState, 15 types::{Dh, Hash, Random}, 16 utils::Toggle, 17 }; 18 use std::{ 19 convert::{TryFrom, TryInto}, 20 fmt, 21 }; 22 23 /// A state machine encompassing the handshake phase of a Noise session. 24 /// 25 /// **Note:** you are probably looking for [`Builder`](struct.Builder.html) to 26 /// get started. 27 /// 28 /// See: [http://noiseprotocol.org/noise.html#the-handshakestate-object](http://noiseprotocol.org/noise.html#the-handshakestate-object) 29 pub struct HandshakeState { 30 pub(crate) rng: Box<dyn Random>, 31 pub(crate) symmetricstate: SymmetricState, 32 pub(crate) cipherstates: CipherStates, 33 pub(crate) s: Toggle<Box<dyn Dh>>, 34 pub(crate) e: Toggle<Box<dyn Dh>>, 35 pub(crate) fixed_ephemeral: bool, 36 pub(crate) rs: Toggle<[u8; MAXDHLEN]>, 37 pub(crate) re: Toggle<[u8; MAXDHLEN]>, 38 pub(crate) initiator: bool, 39 pub(crate) params: NoiseParams, 40 pub(crate) psks: [Option<[u8; PSKLEN]>; 10], 41 #[cfg(feature = "hfs")] 42 pub(crate) kem: Option<Box<dyn Kem>>, 43 #[cfg(feature = "hfs")] 44 pub(crate) kem_re: Option<[u8; MAXKEMPUBLEN]>, 45 pub(crate) my_turn: bool, 46 pub(crate) message_patterns: MessagePatterns, 47 pub(crate) pattern_position: usize, 48 } 49 50 impl HandshakeState { 51 #[allow(clippy::too_many_arguments)] new( rng: Box<dyn Random>, cipherstate: CipherState, hasher: Box<dyn Hash>, s: Toggle<Box<dyn Dh>>, e: Toggle<Box<dyn Dh>>, fixed_ephemeral: bool, rs: Toggle<[u8; MAXDHLEN]>, re: Toggle<[u8; MAXDHLEN]>, initiator: bool, params: NoiseParams, psks: [Option<[u8; PSKLEN]>; 10], prologue: &[u8], cipherstates: CipherStates, ) -> Result<HandshakeState, Error>52 pub(crate) fn new( 53 rng: Box<dyn Random>, 54 cipherstate: CipherState, 55 hasher: Box<dyn Hash>, 56 s: Toggle<Box<dyn Dh>>, 57 e: Toggle<Box<dyn Dh>>, 58 fixed_ephemeral: bool, 59 rs: Toggle<[u8; MAXDHLEN]>, 60 re: Toggle<[u8; MAXDHLEN]>, 61 initiator: bool, 62 params: NoiseParams, 63 psks: [Option<[u8; PSKLEN]>; 10], 64 prologue: &[u8], 65 cipherstates: CipherStates, 66 ) -> Result<HandshakeState, Error> { 67 if (s.is_on() && e.is_on() && s.pub_len() != e.pub_len()) 68 || (s.is_on() && rs.is_on() && s.pub_len() > rs.len()) 69 || (s.is_on() && re.is_on() && s.pub_len() > re.len()) 70 { 71 bail!(InitStage::ValidateKeyLengths); 72 } 73 74 let tokens = HandshakeTokens::try_from(¶ms.handshake)?; 75 76 let mut symmetricstate = SymmetricState::new(cipherstate, hasher); 77 78 symmetricstate.initialize(¶ms.name); 79 symmetricstate.mix_hash(prologue); 80 81 let dh_len = s.pub_len(); 82 if initiator { 83 for token in tokens.premsg_pattern_i { 84 symmetricstate.mix_hash( 85 match *token { 86 Token::S => &s, 87 Token::E => &e, 88 _ => unreachable!(), 89 } 90 .get() 91 .ok_or(StateProblem::MissingKeyMaterial)? 92 .pubkey(), 93 ); 94 } 95 for token in tokens.premsg_pattern_r { 96 symmetricstate.mix_hash( 97 &match *token { 98 Token::S => &rs, 99 Token::E => &re, 100 _ => unreachable!(), 101 } 102 .get() 103 .ok_or(StateProblem::MissingKeyMaterial)?[..dh_len], 104 ); 105 } 106 } else { 107 for token in tokens.premsg_pattern_i { 108 symmetricstate.mix_hash( 109 &match *token { 110 Token::S => &rs, 111 Token::E => &re, 112 _ => unreachable!(), 113 } 114 .get() 115 .ok_or(StateProblem::MissingKeyMaterial)?[..dh_len], 116 ); 117 } 118 for token in tokens.premsg_pattern_r { 119 symmetricstate.mix_hash( 120 match *token { 121 Token::S => &s, 122 Token::E => &e, 123 _ => unreachable!(), 124 } 125 .get() 126 .ok_or(StateProblem::MissingKeyMaterial)? 127 .pubkey(), 128 ); 129 } 130 } 131 132 Ok(HandshakeState { 133 rng, 134 symmetricstate, 135 cipherstates, 136 s, 137 e, 138 fixed_ephemeral, 139 rs, 140 re, 141 initiator, 142 params, 143 psks, 144 #[cfg(feature = "hfs")] 145 kem: None, 146 #[cfg(feature = "hfs")] 147 kem_re: None, 148 my_turn: initiator, 149 message_patterns: tokens.msg_patterns, 150 pattern_position: 0, 151 }) 152 } 153 dh_len(&self) -> usize154 pub(crate) fn dh_len(&self) -> usize { 155 self.s.pub_len() 156 } 157 158 #[cfg(feature = "hfs")] set_kem(&mut self, kem: Box<dyn Kem>)159 pub(crate) fn set_kem(&mut self, kem: Box<dyn Kem>) { 160 self.kem = Some(kem); 161 } 162 dh(&self, token: &DhToken) -> Result<[u8; MAXDHLEN], Error>163 fn dh(&self, token: &DhToken) -> Result<[u8; MAXDHLEN], Error> { 164 let mut dh_out = [0u8; MAXDHLEN]; 165 let (dh, key) = match (token, self.is_initiator()) { 166 (DhToken::Ee, _) => (&self.e, &self.re), 167 (DhToken::Ss, _) => (&self.s, &self.rs), 168 (DhToken::Se, true) | (DhToken::Es, false) => (&self.s, &self.re), 169 (DhToken::Es, true) | (DhToken::Se, false) => (&self.e, &self.rs), 170 }; 171 if !(dh.is_on() && key.is_on()) { 172 bail!(StateProblem::MissingKeyMaterial); 173 } 174 dh.dh(&**key, &mut dh_out).map_err(|_| Error::Dh)?; 175 Ok(dh_out) 176 } 177 178 /// This method will return `true` if the *previous* write payload was encrypted. 179 /// 180 /// See [Payload Security Properties](http://noiseprotocol.org/noise.html#payload-security-properties) 181 /// for more information on the specific properties of your chosen handshake pattern. 182 /// 183 /// # Examples 184 /// 185 /// ```rust,ignore 186 /// let mut session = Builder::new("Noise_NN_25519_AESGCM_SHA256".parse()?) 187 /// .build_initiator()?; 188 /// 189 /// // write message... 190 /// 191 /// assert!(session.was_write_payload_encrypted()); 192 /// ``` was_write_payload_encrypted(&self) -> bool193 pub fn was_write_payload_encrypted(&self) -> bool { 194 self.symmetricstate.has_key() 195 } 196 197 /// Construct a message from `payload` (and pending handshake tokens if in handshake state), 198 /// and writes it to the `message` buffer. 199 /// 200 /// Returns the size of the written payload. 201 /// 202 /// # Errors 203 /// 204 /// Will result in `Error::Input` if the size of the output exceeds the max message 205 /// length in the Noise Protocol (65535 bytes). write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, Error>206 pub fn write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, Error> { 207 let checkpoint = self.symmetricstate.checkpoint(); 208 match self._write_message(payload, message) { 209 Ok(res) => { 210 self.pattern_position += 1; 211 self.my_turn = false; 212 Ok(res) 213 }, 214 Err(err) => { 215 self.symmetricstate.restore(checkpoint); 216 Err(err) 217 }, 218 } 219 } 220 _write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, Error>221 fn _write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, Error> { 222 if !self.my_turn { 223 bail!(StateProblem::NotTurnToWrite); 224 } else if self.pattern_position >= self.message_patterns.len() { 225 bail!(StateProblem::HandshakeAlreadyFinished); 226 } 227 228 let mut byte_index = 0; 229 for token in self.message_patterns[self.pattern_position].iter() { 230 match token { 231 Token::E => { 232 if byte_index + self.e.pub_len() > message.len() { 233 bail!(Error::Input) 234 } 235 236 if !self.fixed_ephemeral { 237 self.e.generate(&mut *self.rng); 238 } 239 let pubkey = self.e.pubkey(); 240 message[byte_index..byte_index + pubkey.len()].copy_from_slice(pubkey); 241 byte_index += pubkey.len(); 242 self.symmetricstate.mix_hash(pubkey); 243 if self.params.handshake.is_psk() { 244 self.symmetricstate.mix_key(pubkey); 245 } 246 self.e.enable(); 247 }, 248 Token::S => { 249 if !self.s.is_on() { 250 bail!(StateProblem::MissingKeyMaterial); 251 } else if byte_index + self.s.pub_len() > message.len() { 252 bail!(Error::Input) 253 } 254 255 byte_index += self 256 .symmetricstate 257 .encrypt_and_mix_hash(self.s.pubkey(), &mut message[byte_index..])?; 258 }, 259 Token::Psk(n) => match self.psks[*n as usize] { 260 Some(psk) => { 261 self.symmetricstate.mix_key_and_hash(&psk); 262 }, 263 None => { 264 bail!(StateProblem::MissingPsk); 265 }, 266 }, 267 Token::Dh(t) => { 268 let dh_out = self.dh(t)?; 269 self.symmetricstate.mix_key(&dh_out[..self.dh_len()]); 270 }, 271 #[cfg(feature = "hfs")] 272 Token::E1 => { 273 let kem = self.kem.as_mut().ok_or(Error::Input)?; 274 if kem.pub_len() > message.len() { 275 bail!(Error::Input); 276 } 277 278 kem.generate(&mut *self.rng); 279 byte_index += self 280 .symmetricstate 281 .encrypt_and_mix_hash(kem.pubkey(), &mut message[byte_index..])?; 282 }, 283 #[cfg(feature = "hfs")] 284 Token::Ekem1 => { 285 let kem = self.kem.as_mut().unwrap(); 286 let mut kem_output_buf = [0; MAXKEMSSLEN]; 287 let mut ciphertext_buf = [0; MAXKEMCTLEN]; 288 289 if kem.ciphertext_len() > message.len() { 290 bail!(Error::Input); 291 } 292 293 let kem_output = &mut kem_output_buf[..kem.shared_secret_len()]; 294 let ciphertext = &mut ciphertext_buf[..kem.ciphertext_len()]; 295 let pubkey = &self.kem_re.as_ref().unwrap()[..kem.pub_len()]; 296 if kem.encapsulate(pubkey, kem_output, ciphertext).is_err() { 297 bail!(Error::Kem); 298 } 299 300 byte_index += self.symmetricstate.encrypt_and_mix_hash( 301 &ciphertext[..kem.ciphertext_len()], 302 &mut message[byte_index..], 303 )?; 304 self.symmetricstate.mix_key(&kem_output[..kem.shared_secret_len()]); 305 }, 306 } 307 } 308 309 if byte_index + payload.len() + TAGLEN > message.len() { 310 bail!(Error::Input); 311 } 312 byte_index += 313 self.symmetricstate.encrypt_and_mix_hash(payload, &mut message[byte_index..])?; 314 if byte_index > MAXMSGLEN { 315 bail!(Error::Input); 316 } 317 if self.pattern_position == (self.message_patterns.len() - 1) { 318 self.symmetricstate.split(&mut self.cipherstates.0, &mut self.cipherstates.1); 319 } 320 Ok(byte_index) 321 } 322 323 /// Reads a noise message from `input` 324 /// 325 /// Returns the size of the payload written to `payload`. 326 /// 327 /// # Errors 328 /// 329 /// Will result in `Error::Decrypt` if the contents couldn't be decrypted and/or the 330 /// authentication tag didn't verify. 331 /// 332 /// # Panics 333 /// 334 /// This function will panic if there is no key, or if there is a nonce overflow. read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, Error>335 pub fn read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, Error> { 336 let checkpoint = self.symmetricstate.checkpoint(); 337 match self._read_message(message, payload) { 338 Ok(res) => { 339 self.pattern_position += 1; 340 self.my_turn = true; 341 Ok(res) 342 }, 343 Err(err) => { 344 self.symmetricstate.restore(checkpoint); 345 Err(err) 346 }, 347 } 348 } 349 _read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, Error>350 fn _read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, Error> { 351 if message.len() > MAXMSGLEN { 352 bail!(Error::Input); 353 } else if self.my_turn { 354 bail!(StateProblem::NotTurnToRead); 355 } else if self.pattern_position >= self.message_patterns.len() { 356 bail!(StateProblem::HandshakeAlreadyFinished); 357 } 358 let last = self.pattern_position == (self.message_patterns.len() - 1); 359 360 let dh_len = self.dh_len(); 361 let mut ptr = message; 362 for token in self.message_patterns[self.pattern_position].iter() { 363 match token { 364 Token::E => { 365 if ptr.len() < dh_len { 366 bail!(Error::Input); 367 } 368 self.re[..dh_len].copy_from_slice(&ptr[..dh_len]); 369 ptr = &ptr[dh_len..]; 370 self.symmetricstate.mix_hash(&self.re[..dh_len]); 371 if self.params.handshake.is_psk() { 372 self.symmetricstate.mix_key(&self.re[..dh_len]); 373 } 374 self.re.enable(); 375 }, 376 Token::S => { 377 let data = if self.symmetricstate.has_key() { 378 if ptr.len() < dh_len + TAGLEN { 379 bail!(Error::Input); 380 } 381 let temp = &ptr[..dh_len + TAGLEN]; 382 ptr = &ptr[dh_len + TAGLEN..]; 383 temp 384 } else { 385 if ptr.len() < dh_len { 386 bail!(Error::Input); 387 } 388 let temp = &ptr[..dh_len]; 389 ptr = &ptr[dh_len..]; 390 temp 391 }; 392 self.symmetricstate 393 .decrypt_and_mix_hash(data, &mut self.rs[..dh_len]) 394 .map_err(|_| Error::Decrypt)?; 395 self.rs.enable(); 396 }, 397 Token::Psk(n) => match self.psks[*n as usize] { 398 Some(psk) => { 399 self.symmetricstate.mix_key_and_hash(&psk); 400 }, 401 None => { 402 bail!(StateProblem::MissingPsk); 403 }, 404 }, 405 Token::Dh(t) => { 406 let dh_out = self.dh(t)?; 407 self.symmetricstate.mix_key(&dh_out[..self.dh_len()]); 408 }, 409 #[cfg(feature = "hfs")] 410 Token::E1 => { 411 let kem = self.kem.as_ref().ok_or(Error::Kem)?; 412 let read_len = if self.symmetricstate.has_key() { 413 kem.pub_len() + TAGLEN 414 } else { 415 kem.pub_len() 416 }; 417 if ptr.len() < read_len { 418 bail!(Error::Input); 419 } 420 let mut kem_re = [0; MAXKEMPUBLEN]; 421 self.symmetricstate 422 .decrypt_and_mix_hash(&ptr[..read_len], &mut kem_re[..kem.pub_len()]) 423 .map_err(|_| Error::Decrypt)?; 424 self.kem_re = Some(kem_re); 425 ptr = &ptr[read_len..]; 426 }, 427 #[cfg(feature = "hfs")] 428 Token::Ekem1 => { 429 let kem = self.kem.as_ref().unwrap(); 430 let read_len = if self.symmetricstate.has_key() { 431 kem.ciphertext_len() + TAGLEN 432 } else { 433 kem.ciphertext_len() 434 }; 435 if ptr.len() < read_len { 436 bail!(Error::Input); 437 } 438 let mut ciphertext_buf = [0; MAXKEMCTLEN]; 439 let ciphertext = &mut ciphertext_buf[..kem.ciphertext_len()]; 440 self.symmetricstate 441 .decrypt_and_mix_hash(&ptr[..read_len], ciphertext) 442 .map_err(|_| Error::Decrypt)?; 443 let mut kem_output_buf = [0; MAXKEMSSLEN]; 444 let kem_output = &mut kem_output_buf[..kem.shared_secret_len()]; 445 kem.decapsulate(ciphertext, kem_output).map_err(|_| Error::Kem)?; 446 self.symmetricstate.mix_key(&kem_output[..kem.shared_secret_len()]); 447 ptr = &ptr[read_len..]; 448 }, 449 } 450 } 451 452 self.symmetricstate.decrypt_and_mix_hash(ptr, payload).map_err(|_| Error::Decrypt)?; 453 if last { 454 self.symmetricstate.split(&mut self.cipherstates.0, &mut self.cipherstates.1); 455 } 456 let payload_len = 457 if self.symmetricstate.has_key() { ptr.len() - TAGLEN } else { ptr.len() }; 458 Ok(payload_len) 459 } 460 461 /// Set the preshared key at the specified location. It is up to the caller 462 /// to correctly set the location based on the specified handshake - Snow 463 /// won't stop you from placing a PSK in an unused slot. 464 /// 465 /// # Errors 466 /// 467 /// Will result in `Error::Input` if the PSK is not the right length or the location is out of bounds. set_psk(&mut self, location: usize, key: &[u8]) -> Result<(), Error>468 pub fn set_psk(&mut self, location: usize, key: &[u8]) -> Result<(), Error> { 469 if key.len() != PSKLEN || self.psks.len() <= location { 470 bail!(Error::Input); 471 } 472 473 let mut new_psk = [0u8; PSKLEN]; 474 new_psk.copy_from_slice(key); 475 self.psks[location as usize] = Some(new_psk); 476 477 Ok(()) 478 } 479 480 /// Get the remote party's static public key, if available. 481 /// 482 /// Note: will return `None` if either the chosen Noise pattern 483 /// doesn't necessitate a remote static key, *or* if the remote 484 /// static key is not yet known (as can be the case in the `XX` 485 /// pattern, for example). get_remote_static(&self) -> Option<&[u8]>486 pub fn get_remote_static(&self) -> Option<&[u8]> { 487 self.rs.get().map(|rs| &rs[..self.dh_len()]) 488 } 489 490 /// Get the handshake hash. 491 /// 492 /// Returns a slice of length `Hasher.hash_len()` (i.e. HASHLEN for the chosen Hash function). get_handshake_hash(&self) -> &[u8]493 pub fn get_handshake_hash(&self) -> &[u8] { 494 self.symmetricstate.handshake_hash() 495 } 496 497 /// Check if this session was started with the "initiator" role. is_initiator(&self) -> bool498 pub fn is_initiator(&self) -> bool { 499 self.initiator 500 } 501 502 /// Check if the handshake is finished and `into_transport_mode()` can now be called. is_handshake_finished(&self) -> bool503 pub fn is_handshake_finished(&self) -> bool { 504 self.pattern_position == self.message_patterns.len() 505 } 506 507 /// Check whether it is our turn to send in the handshake state machine is_my_turn(&self) -> bool508 pub fn is_my_turn(&self) -> bool { 509 self.my_turn 510 } 511 512 /// Perform the split calculation and return the resulting keys. 513 /// 514 /// This returns raw key material so it should be used with care. The "risky-raw-split" 515 /// feature has to be enabled to use this function. 516 #[cfg(feature = "risky-raw-split")] dangerously_get_raw_split(&mut self) -> ([u8; CIPHERKEYLEN], [u8; CIPHERKEYLEN])517 pub fn dangerously_get_raw_split(&mut self) -> ([u8; CIPHERKEYLEN], [u8; CIPHERKEYLEN]) { 518 let mut output = ([0u8; MAXHASHLEN], [0u8; MAXHASHLEN]); 519 self.symmetricstate.split_raw(&mut output.0, &mut output.1); 520 (output.0[..CIPHERKEYLEN].try_into().unwrap(), output.1[..CIPHERKEYLEN].try_into().unwrap()) 521 } 522 523 /// Convert this `HandshakeState` into a `TransportState` with an internally stored nonce. into_transport_mode(self) -> Result<TransportState, Error>524 pub fn into_transport_mode(self) -> Result<TransportState, Error> { 525 self.try_into() 526 } 527 528 /// Convert this `HandshakeState` into a `StatelessTransportState` without an internally stored nonce. into_stateless_transport_mode(self) -> Result<StatelessTransportState, Error>529 pub fn into_stateless_transport_mode(self) -> Result<StatelessTransportState, Error> { 530 self.try_into() 531 } 532 } 533 534 impl fmt::Debug for HandshakeState { fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result535 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 536 fmt.debug_struct("HandshakeState").finish() 537 } 538 } 539