1 #[cfg(feature = "risky-raw-split")]
2 use crate::constants::{CIPHERKEYLEN, MAXHASHLEN};
3 #[cfg(feature = "hfs")]
4 use crate::constants::{MAXKEMCTLEN, MAXKEMPUBLEN, MAXKEMSSLEN};
5 #[cfg(feature = "hfs")]
6 use crate::types::Kem;
7 use crate::{
8     cipherstate::{CipherState, CipherStates},
9     constants::{MAXDHLEN, MAXMSGLEN, PSKLEN, TAGLEN},
10     error::{Error, InitStage, StateProblem},
11     params::{DhToken, HandshakeTokens, MessagePatterns, NoiseParams, Token},
12     stateless_transportstate::StatelessTransportState,
13     symmetricstate::SymmetricState,
14     transportstate::TransportState,
15     types::{Dh, Hash, Random},
16     utils::Toggle,
17 };
18 use std::{
19     convert::{TryFrom, TryInto},
20     fmt,
21 };
22 
23 /// A state machine encompassing the handshake phase of a Noise session.
24 ///
25 /// **Note:** you are probably looking for [`Builder`](struct.Builder.html) to
26 /// get started.
27 ///
28 /// See: [http://noiseprotocol.org/noise.html#the-handshakestate-object](http://noiseprotocol.org/noise.html#the-handshakestate-object)
29 pub struct HandshakeState {
30     pub(crate) rng:              Box<dyn Random>,
31     pub(crate) symmetricstate:   SymmetricState,
32     pub(crate) cipherstates:     CipherStates,
33     pub(crate) s:                Toggle<Box<dyn Dh>>,
34     pub(crate) e:                Toggle<Box<dyn Dh>>,
35     pub(crate) fixed_ephemeral:  bool,
36     pub(crate) rs:               Toggle<[u8; MAXDHLEN]>,
37     pub(crate) re:               Toggle<[u8; MAXDHLEN]>,
38     pub(crate) initiator:        bool,
39     pub(crate) params:           NoiseParams,
40     pub(crate) psks:             [Option<[u8; PSKLEN]>; 10],
41     #[cfg(feature = "hfs")]
42     pub(crate) kem:              Option<Box<dyn Kem>>,
43     #[cfg(feature = "hfs")]
44     pub(crate) kem_re:           Option<[u8; MAXKEMPUBLEN]>,
45     pub(crate) my_turn:          bool,
46     pub(crate) message_patterns: MessagePatterns,
47     pub(crate) pattern_position: usize,
48 }
49 
50 impl HandshakeState {
51     #[allow(clippy::too_many_arguments)]
new( rng: Box<dyn Random>, cipherstate: CipherState, hasher: Box<dyn Hash>, s: Toggle<Box<dyn Dh>>, e: Toggle<Box<dyn Dh>>, fixed_ephemeral: bool, rs: Toggle<[u8; MAXDHLEN]>, re: Toggle<[u8; MAXDHLEN]>, initiator: bool, params: NoiseParams, psks: [Option<[u8; PSKLEN]>; 10], prologue: &[u8], cipherstates: CipherStates, ) -> Result<HandshakeState, Error>52     pub(crate) fn new(
53         rng: Box<dyn Random>,
54         cipherstate: CipherState,
55         hasher: Box<dyn Hash>,
56         s: Toggle<Box<dyn Dh>>,
57         e: Toggle<Box<dyn Dh>>,
58         fixed_ephemeral: bool,
59         rs: Toggle<[u8; MAXDHLEN]>,
60         re: Toggle<[u8; MAXDHLEN]>,
61         initiator: bool,
62         params: NoiseParams,
63         psks: [Option<[u8; PSKLEN]>; 10],
64         prologue: &[u8],
65         cipherstates: CipherStates,
66     ) -> Result<HandshakeState, Error> {
67         if (s.is_on() && e.is_on() && s.pub_len() != e.pub_len())
68             || (s.is_on() && rs.is_on() && s.pub_len() > rs.len())
69             || (s.is_on() && re.is_on() && s.pub_len() > re.len())
70         {
71             bail!(InitStage::ValidateKeyLengths);
72         }
73 
74         let tokens = HandshakeTokens::try_from(&params.handshake)?;
75 
76         let mut symmetricstate = SymmetricState::new(cipherstate, hasher);
77 
78         symmetricstate.initialize(&params.name);
79         symmetricstate.mix_hash(prologue);
80 
81         let dh_len = s.pub_len();
82         if initiator {
83             for token in tokens.premsg_pattern_i {
84                 symmetricstate.mix_hash(
85                     match *token {
86                         Token::S => &s,
87                         Token::E => &e,
88                         _ => unreachable!(),
89                     }
90                     .get()
91                     .ok_or(StateProblem::MissingKeyMaterial)?
92                     .pubkey(),
93                 );
94             }
95             for token in tokens.premsg_pattern_r {
96                 symmetricstate.mix_hash(
97                     &match *token {
98                         Token::S => &rs,
99                         Token::E => &re,
100                         _ => unreachable!(),
101                     }
102                     .get()
103                     .ok_or(StateProblem::MissingKeyMaterial)?[..dh_len],
104                 );
105             }
106         } else {
107             for token in tokens.premsg_pattern_i {
108                 symmetricstate.mix_hash(
109                     &match *token {
110                         Token::S => &rs,
111                         Token::E => &re,
112                         _ => unreachable!(),
113                     }
114                     .get()
115                     .ok_or(StateProblem::MissingKeyMaterial)?[..dh_len],
116                 );
117             }
118             for token in tokens.premsg_pattern_r {
119                 symmetricstate.mix_hash(
120                     match *token {
121                         Token::S => &s,
122                         Token::E => &e,
123                         _ => unreachable!(),
124                     }
125                     .get()
126                     .ok_or(StateProblem::MissingKeyMaterial)?
127                     .pubkey(),
128                 );
129             }
130         }
131 
132         Ok(HandshakeState {
133             rng,
134             symmetricstate,
135             cipherstates,
136             s,
137             e,
138             fixed_ephemeral,
139             rs,
140             re,
141             initiator,
142             params,
143             psks,
144             #[cfg(feature = "hfs")]
145             kem: None,
146             #[cfg(feature = "hfs")]
147             kem_re: None,
148             my_turn: initiator,
149             message_patterns: tokens.msg_patterns,
150             pattern_position: 0,
151         })
152     }
153 
dh_len(&self) -> usize154     pub(crate) fn dh_len(&self) -> usize {
155         self.s.pub_len()
156     }
157 
158     #[cfg(feature = "hfs")]
set_kem(&mut self, kem: Box<dyn Kem>)159     pub(crate) fn set_kem(&mut self, kem: Box<dyn Kem>) {
160         self.kem = Some(kem);
161     }
162 
dh(&self, token: &DhToken) -> Result<[u8; MAXDHLEN], Error>163     fn dh(&self, token: &DhToken) -> Result<[u8; MAXDHLEN], Error> {
164         let mut dh_out = [0u8; MAXDHLEN];
165         let (dh, key) = match (token, self.is_initiator()) {
166             (DhToken::Ee, _) => (&self.e, &self.re),
167             (DhToken::Ss, _) => (&self.s, &self.rs),
168             (DhToken::Se, true) | (DhToken::Es, false) => (&self.s, &self.re),
169             (DhToken::Es, true) | (DhToken::Se, false) => (&self.e, &self.rs),
170         };
171         if !(dh.is_on() && key.is_on()) {
172             bail!(StateProblem::MissingKeyMaterial);
173         }
174         dh.dh(&**key, &mut dh_out).map_err(|_| Error::Dh)?;
175         Ok(dh_out)
176     }
177 
178     /// This method will return `true` if the *previous* write payload was encrypted.
179     ///
180     /// See [Payload Security Properties](http://noiseprotocol.org/noise.html#payload-security-properties)
181     /// for more information on the specific properties of your chosen handshake pattern.
182     ///
183     /// # Examples
184     ///
185     /// ```rust,ignore
186     /// let mut session = Builder::new("Noise_NN_25519_AESGCM_SHA256".parse()?)
187     ///     .build_initiator()?;
188     ///
189     /// // write message...
190     ///
191     /// assert!(session.was_write_payload_encrypted());
192     /// ```
was_write_payload_encrypted(&self) -> bool193     pub fn was_write_payload_encrypted(&self) -> bool {
194         self.symmetricstate.has_key()
195     }
196 
197     /// Construct a message from `payload` (and pending handshake tokens if in handshake state),
198     /// and writes it to the `message` buffer.
199     ///
200     /// Returns the size of the written payload.
201     ///
202     /// # Errors
203     ///
204     /// Will result in `Error::Input` if the size of the output exceeds the max message
205     /// length in the Noise Protocol (65535 bytes).
write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, Error>206     pub fn write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, Error> {
207         let checkpoint = self.symmetricstate.checkpoint();
208         match self._write_message(payload, message) {
209             Ok(res) => {
210                 self.pattern_position += 1;
211                 self.my_turn = false;
212                 Ok(res)
213             },
214             Err(err) => {
215                 self.symmetricstate.restore(checkpoint);
216                 Err(err)
217             },
218         }
219     }
220 
_write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, Error>221     fn _write_message(&mut self, payload: &[u8], message: &mut [u8]) -> Result<usize, Error> {
222         if !self.my_turn {
223             bail!(StateProblem::NotTurnToWrite);
224         } else if self.pattern_position >= self.message_patterns.len() {
225             bail!(StateProblem::HandshakeAlreadyFinished);
226         }
227 
228         let mut byte_index = 0;
229         for token in self.message_patterns[self.pattern_position].iter() {
230             match token {
231                 Token::E => {
232                     if byte_index + self.e.pub_len() > message.len() {
233                         bail!(Error::Input)
234                     }
235 
236                     if !self.fixed_ephemeral {
237                         self.e.generate(&mut *self.rng);
238                     }
239                     let pubkey = self.e.pubkey();
240                     message[byte_index..byte_index + pubkey.len()].copy_from_slice(pubkey);
241                     byte_index += pubkey.len();
242                     self.symmetricstate.mix_hash(pubkey);
243                     if self.params.handshake.is_psk() {
244                         self.symmetricstate.mix_key(pubkey);
245                     }
246                     self.e.enable();
247                 },
248                 Token::S => {
249                     if !self.s.is_on() {
250                         bail!(StateProblem::MissingKeyMaterial);
251                     } else if byte_index + self.s.pub_len() > message.len() {
252                         bail!(Error::Input)
253                     }
254 
255                     byte_index += self
256                         .symmetricstate
257                         .encrypt_and_mix_hash(self.s.pubkey(), &mut message[byte_index..])?;
258                 },
259                 Token::Psk(n) => match self.psks[*n as usize] {
260                     Some(psk) => {
261                         self.symmetricstate.mix_key_and_hash(&psk);
262                     },
263                     None => {
264                         bail!(StateProblem::MissingPsk);
265                     },
266                 },
267                 Token::Dh(t) => {
268                     let dh_out = self.dh(t)?;
269                     self.symmetricstate.mix_key(&dh_out[..self.dh_len()]);
270                 },
271                 #[cfg(feature = "hfs")]
272                 Token::E1 => {
273                     let kem = self.kem.as_mut().ok_or(Error::Input)?;
274                     if kem.pub_len() > message.len() {
275                         bail!(Error::Input);
276                     }
277 
278                     kem.generate(&mut *self.rng);
279                     byte_index += self
280                         .symmetricstate
281                         .encrypt_and_mix_hash(kem.pubkey(), &mut message[byte_index..])?;
282                 },
283                 #[cfg(feature = "hfs")]
284                 Token::Ekem1 => {
285                     let kem = self.kem.as_mut().unwrap();
286                     let mut kem_output_buf = [0; MAXKEMSSLEN];
287                     let mut ciphertext_buf = [0; MAXKEMCTLEN];
288 
289                     if kem.ciphertext_len() > message.len() {
290                         bail!(Error::Input);
291                     }
292 
293                     let kem_output = &mut kem_output_buf[..kem.shared_secret_len()];
294                     let ciphertext = &mut ciphertext_buf[..kem.ciphertext_len()];
295                     let pubkey = &self.kem_re.as_ref().unwrap()[..kem.pub_len()];
296                     if kem.encapsulate(pubkey, kem_output, ciphertext).is_err() {
297                         bail!(Error::Kem);
298                     }
299 
300                     byte_index += self.symmetricstate.encrypt_and_mix_hash(
301                         &ciphertext[..kem.ciphertext_len()],
302                         &mut message[byte_index..],
303                     )?;
304                     self.symmetricstate.mix_key(&kem_output[..kem.shared_secret_len()]);
305                 },
306             }
307         }
308 
309         if byte_index + payload.len() + TAGLEN > message.len() {
310             bail!(Error::Input);
311         }
312         byte_index +=
313             self.symmetricstate.encrypt_and_mix_hash(payload, &mut message[byte_index..])?;
314         if byte_index > MAXMSGLEN {
315             bail!(Error::Input);
316         }
317         if self.pattern_position == (self.message_patterns.len() - 1) {
318             self.symmetricstate.split(&mut self.cipherstates.0, &mut self.cipherstates.1);
319         }
320         Ok(byte_index)
321     }
322 
323     /// Reads a noise message from `input`
324     ///
325     /// Returns the size of the payload written to `payload`.
326     ///
327     /// # Errors
328     ///
329     /// Will result in `Error::Decrypt` if the contents couldn't be decrypted and/or the
330     /// authentication tag didn't verify.
331     ///
332     /// # Panics
333     ///
334     /// This function will panic if there is no key, or if there is a nonce overflow.
read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, Error>335     pub fn read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, Error> {
336         let checkpoint = self.symmetricstate.checkpoint();
337         match self._read_message(message, payload) {
338             Ok(res) => {
339                 self.pattern_position += 1;
340                 self.my_turn = true;
341                 Ok(res)
342             },
343             Err(err) => {
344                 self.symmetricstate.restore(checkpoint);
345                 Err(err)
346             },
347         }
348     }
349 
_read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, Error>350     fn _read_message(&mut self, message: &[u8], payload: &mut [u8]) -> Result<usize, Error> {
351         if message.len() > MAXMSGLEN {
352             bail!(Error::Input);
353         } else if self.my_turn {
354             bail!(StateProblem::NotTurnToRead);
355         } else if self.pattern_position >= self.message_patterns.len() {
356             bail!(StateProblem::HandshakeAlreadyFinished);
357         }
358         let last = self.pattern_position == (self.message_patterns.len() - 1);
359 
360         let dh_len = self.dh_len();
361         let mut ptr = message;
362         for token in self.message_patterns[self.pattern_position].iter() {
363             match token {
364                 Token::E => {
365                     if ptr.len() < dh_len {
366                         bail!(Error::Input);
367                     }
368                     self.re[..dh_len].copy_from_slice(&ptr[..dh_len]);
369                     ptr = &ptr[dh_len..];
370                     self.symmetricstate.mix_hash(&self.re[..dh_len]);
371                     if self.params.handshake.is_psk() {
372                         self.symmetricstate.mix_key(&self.re[..dh_len]);
373                     }
374                     self.re.enable();
375                 },
376                 Token::S => {
377                     let data = if self.symmetricstate.has_key() {
378                         if ptr.len() < dh_len + TAGLEN {
379                             bail!(Error::Input);
380                         }
381                         let temp = &ptr[..dh_len + TAGLEN];
382                         ptr = &ptr[dh_len + TAGLEN..];
383                         temp
384                     } else {
385                         if ptr.len() < dh_len {
386                             bail!(Error::Input);
387                         }
388                         let temp = &ptr[..dh_len];
389                         ptr = &ptr[dh_len..];
390                         temp
391                     };
392                     self.symmetricstate
393                         .decrypt_and_mix_hash(data, &mut self.rs[..dh_len])
394                         .map_err(|_| Error::Decrypt)?;
395                     self.rs.enable();
396                 },
397                 Token::Psk(n) => match self.psks[*n as usize] {
398                     Some(psk) => {
399                         self.symmetricstate.mix_key_and_hash(&psk);
400                     },
401                     None => {
402                         bail!(StateProblem::MissingPsk);
403                     },
404                 },
405                 Token::Dh(t) => {
406                     let dh_out = self.dh(t)?;
407                     self.symmetricstate.mix_key(&dh_out[..self.dh_len()]);
408                 },
409                 #[cfg(feature = "hfs")]
410                 Token::E1 => {
411                     let kem = self.kem.as_ref().ok_or(Error::Kem)?;
412                     let read_len = if self.symmetricstate.has_key() {
413                         kem.pub_len() + TAGLEN
414                     } else {
415                         kem.pub_len()
416                     };
417                     if ptr.len() < read_len {
418                         bail!(Error::Input);
419                     }
420                     let mut kem_re = [0; MAXKEMPUBLEN];
421                     self.symmetricstate
422                         .decrypt_and_mix_hash(&ptr[..read_len], &mut kem_re[..kem.pub_len()])
423                         .map_err(|_| Error::Decrypt)?;
424                     self.kem_re = Some(kem_re);
425                     ptr = &ptr[read_len..];
426                 },
427                 #[cfg(feature = "hfs")]
428                 Token::Ekem1 => {
429                     let kem = self.kem.as_ref().unwrap();
430                     let read_len = if self.symmetricstate.has_key() {
431                         kem.ciphertext_len() + TAGLEN
432                     } else {
433                         kem.ciphertext_len()
434                     };
435                     if ptr.len() < read_len {
436                         bail!(Error::Input);
437                     }
438                     let mut ciphertext_buf = [0; MAXKEMCTLEN];
439                     let ciphertext = &mut ciphertext_buf[..kem.ciphertext_len()];
440                     self.symmetricstate
441                         .decrypt_and_mix_hash(&ptr[..read_len], ciphertext)
442                         .map_err(|_| Error::Decrypt)?;
443                     let mut kem_output_buf = [0; MAXKEMSSLEN];
444                     let kem_output = &mut kem_output_buf[..kem.shared_secret_len()];
445                     kem.decapsulate(ciphertext, kem_output).map_err(|_| Error::Kem)?;
446                     self.symmetricstate.mix_key(&kem_output[..kem.shared_secret_len()]);
447                     ptr = &ptr[read_len..];
448                 },
449             }
450         }
451 
452         self.symmetricstate.decrypt_and_mix_hash(ptr, payload).map_err(|_| Error::Decrypt)?;
453         if last {
454             self.symmetricstate.split(&mut self.cipherstates.0, &mut self.cipherstates.1);
455         }
456         let payload_len =
457             if self.symmetricstate.has_key() { ptr.len() - TAGLEN } else { ptr.len() };
458         Ok(payload_len)
459     }
460 
461     /// Set the preshared key at the specified location. It is up to the caller
462     /// to correctly set the location based on the specified handshake - Snow
463     /// won't stop you from placing a PSK in an unused slot.
464     ///
465     /// # Errors
466     ///
467     /// Will result in `Error::Input` if the PSK is not the right length or the location is out of bounds.
set_psk(&mut self, location: usize, key: &[u8]) -> Result<(), Error>468     pub fn set_psk(&mut self, location: usize, key: &[u8]) -> Result<(), Error> {
469         if key.len() != PSKLEN || self.psks.len() <= location {
470             bail!(Error::Input);
471         }
472 
473         let mut new_psk = [0u8; PSKLEN];
474         new_psk.copy_from_slice(key);
475         self.psks[location as usize] = Some(new_psk);
476 
477         Ok(())
478     }
479 
480     /// Get the remote party's static public key, if available.
481     ///
482     /// Note: will return `None` if either the chosen Noise pattern
483     /// doesn't necessitate a remote static key, *or* if the remote
484     /// static key is not yet known (as can be the case in the `XX`
485     /// pattern, for example).
get_remote_static(&self) -> Option<&[u8]>486     pub fn get_remote_static(&self) -> Option<&[u8]> {
487         self.rs.get().map(|rs| &rs[..self.dh_len()])
488     }
489 
490     /// Get the handshake hash.
491     ///
492     /// Returns a slice of length `Hasher.hash_len()` (i.e. HASHLEN for the chosen Hash function).
get_handshake_hash(&self) -> &[u8]493     pub fn get_handshake_hash(&self) -> &[u8] {
494         self.symmetricstate.handshake_hash()
495     }
496 
497     /// Check if this session was started with the "initiator" role.
is_initiator(&self) -> bool498     pub fn is_initiator(&self) -> bool {
499         self.initiator
500     }
501 
502     /// Check if the handshake is finished and `into_transport_mode()` can now be called.
is_handshake_finished(&self) -> bool503     pub fn is_handshake_finished(&self) -> bool {
504         self.pattern_position == self.message_patterns.len()
505     }
506 
507     /// Check whether it is our turn to send in the handshake state machine
is_my_turn(&self) -> bool508     pub fn is_my_turn(&self) -> bool {
509         self.my_turn
510     }
511 
512     /// Perform the split calculation and return the resulting keys.
513     ///
514     /// This returns raw key material so it should be used with care. The "risky-raw-split"
515     /// feature has to be enabled to use this function.
516     #[cfg(feature = "risky-raw-split")]
dangerously_get_raw_split(&mut self) -> ([u8; CIPHERKEYLEN], [u8; CIPHERKEYLEN])517     pub fn dangerously_get_raw_split(&mut self) -> ([u8; CIPHERKEYLEN], [u8; CIPHERKEYLEN]) {
518         let mut output = ([0u8; MAXHASHLEN], [0u8; MAXHASHLEN]);
519         self.symmetricstate.split_raw(&mut output.0, &mut output.1);
520         (output.0[..CIPHERKEYLEN].try_into().unwrap(), output.1[..CIPHERKEYLEN].try_into().unwrap())
521     }
522 
523     /// Convert this `HandshakeState` into a `TransportState` with an internally stored nonce.
into_transport_mode(self) -> Result<TransportState, Error>524     pub fn into_transport_mode(self) -> Result<TransportState, Error> {
525         self.try_into()
526     }
527 
528     /// Convert this `HandshakeState` into a `StatelessTransportState` without an internally stored nonce.
into_stateless_transport_mode(self) -> Result<StatelessTransportState, Error>529     pub fn into_stateless_transport_mode(self) -> Result<StatelessTransportState, Error> {
530         self.try_into()
531     }
532 }
533 
534 impl fmt::Debug for HandshakeState {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result535     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
536         fmt.debug_struct("HandshakeState").finish()
537     }
538 }
539