1 //
2 // Copyright 2020-2021 Signal Messenger, LLC.
3 // SPDX-License-Identifier: AGPL-3.0-only
4 //
5 
use std::convert::TryFrom;

use crate::{
    CiphertextMessage, Context, Direction, IdentityKeyStore, KeyPair, PreKeySignalMessage,
    PreKeyStore, ProtocolAddress, PublicKey, Result, SessionRecord, SessionStore, SignalMessage,
    SignalProtocolError, SignedPreKeyStore,
};

use crate::consts::MAX_FORWARD_JUMPS;
use crate::crypto;
use crate::ratchet::{ChainKey, MessageKeys};
use crate::session;
use crate::state::SessionState;

use rand::{CryptoRng, Rng};
19 
message_encrypt( ptext: &[u8], remote_address: &ProtocolAddress, session_store: &mut dyn SessionStore, identity_store: &mut dyn IdentityKeyStore, ctx: Context, ) -> Result<CiphertextMessage>20 pub async fn message_encrypt(
21     ptext: &[u8],
22     remote_address: &ProtocolAddress,
23     session_store: &mut dyn SessionStore,
24     identity_store: &mut dyn IdentityKeyStore,
25     ctx: Context,
26 ) -> Result<CiphertextMessage> {
27     let mut session_record = session_store
28         .load_session(remote_address, ctx)
29         .await?
30         .ok_or_else(|| SignalProtocolError::SessionNotFound(format!("{}", remote_address)))?;
31     let session_state = session_record.session_state_mut()?;
32 
33     let chain_key = session_state.get_sender_chain_key()?;
34 
35     let message_keys = chain_key.message_keys()?;
36 
37     let sender_ephemeral = session_state.sender_ratchet_key()?;
38     let previous_counter = session_state.previous_counter()?;
39     let session_version = session_state.session_version()? as u8;
40 
41     let local_identity_key = session_state.local_identity_key()?;
42     let their_identity_key = session_state
43         .remote_identity_key()?
44         .ok_or(SignalProtocolError::InvalidSessionStructure)?;
45 
46     let ctext = crypto::aes_256_cbc_encrypt(ptext, message_keys.cipher_key(), message_keys.iv())?;
47 
48     let message = if let Some(items) = session_state.unacknowledged_pre_key_message_items()? {
49         let local_registration_id = session_state.local_registration_id()?;
50 
51         log::info!(
52             "Building PreKeyWhisperMessage for: {} with preKeyId: {}",
53             remote_address,
54             items
55                 .pre_key_id()?
56                 .map_or_else(|| "<none>".to_string(), |id| id.to_string())
57         );
58 
59         let message = SignalMessage::new(
60             session_version,
61             message_keys.mac_key(),
62             sender_ephemeral,
63             chain_key.index(),
64             previous_counter,
65             &ctext,
66             &local_identity_key,
67             &their_identity_key,
68         )?;
69 
70         CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::new(
71             session_version,
72             local_registration_id,
73             items.pre_key_id()?,
74             items.signed_pre_key_id()?,
75             *items.base_key()?,
76             local_identity_key,
77             message,
78         )?)
79     } else {
80         CiphertextMessage::SignalMessage(SignalMessage::new(
81             session_version,
82             message_keys.mac_key(),
83             sender_ephemeral,
84             chain_key.index(),
85             previous_counter,
86             &ctext,
87             &local_identity_key,
88             &their_identity_key,
89         )?)
90     };
91 
92     session_state.set_sender_chain_key(&chain_key.next_chain_key()?)?;
93 
94     // XXX why is this check after everything else?!!
95     if !identity_store
96         .is_trusted_identity(remote_address, &their_identity_key, Direction::Sending, ctx)
97         .await?
98     {
99         log::warn!(
100             "Identity key {} is not trusted for remote address {}",
101             their_identity_key
102                 .public_key()
103                 .public_key_bytes()
104                 .map_or_else(|e| format!("<error: {}>", e), hex::encode),
105             remote_address,
106         );
107         return Err(SignalProtocolError::UntrustedIdentity(
108             remote_address.clone(),
109         ));
110     }
111 
112     // XXX this could be combined with the above call to the identity store (in a new API)
113     identity_store
114         .save_identity(remote_address, &their_identity_key, ctx)
115         .await?;
116 
117     session_store
118         .store_session(remote_address, &session_record, ctx)
119         .await?;
120     Ok(message)
121 }
122 
message_decrypt<R: Rng + CryptoRng>( ciphertext: &CiphertextMessage, remote_address: &ProtocolAddress, session_store: &mut dyn SessionStore, identity_store: &mut dyn IdentityKeyStore, pre_key_store: &mut dyn PreKeyStore, signed_pre_key_store: &mut dyn SignedPreKeyStore, csprng: &mut R, ctx: Context, ) -> Result<Vec<u8>>123 pub async fn message_decrypt<R: Rng + CryptoRng>(
124     ciphertext: &CiphertextMessage,
125     remote_address: &ProtocolAddress,
126     session_store: &mut dyn SessionStore,
127     identity_store: &mut dyn IdentityKeyStore,
128     pre_key_store: &mut dyn PreKeyStore,
129     signed_pre_key_store: &mut dyn SignedPreKeyStore,
130     csprng: &mut R,
131     ctx: Context,
132 ) -> Result<Vec<u8>> {
133     match ciphertext {
134         CiphertextMessage::SignalMessage(m) => {
135             message_decrypt_signal(
136                 m,
137                 remote_address,
138                 session_store,
139                 identity_store,
140                 csprng,
141                 ctx,
142             )
143             .await
144         }
145         CiphertextMessage::PreKeySignalMessage(m) => {
146             message_decrypt_prekey(
147                 m,
148                 remote_address,
149                 session_store,
150                 identity_store,
151                 pre_key_store,
152                 signed_pre_key_store,
153                 csprng,
154                 ctx,
155             )
156             .await
157         }
158         _ => Err(SignalProtocolError::InvalidArgument(
159             "SessionCipher::decrypt cannot decrypt this message type".to_owned(),
160         )),
161     }
162 }
163 
message_decrypt_prekey<R: Rng + CryptoRng>( ciphertext: &PreKeySignalMessage, remote_address: &ProtocolAddress, session_store: &mut dyn SessionStore, identity_store: &mut dyn IdentityKeyStore, pre_key_store: &mut dyn PreKeyStore, signed_pre_key_store: &mut dyn SignedPreKeyStore, csprng: &mut R, ctx: Context, ) -> Result<Vec<u8>>164 pub async fn message_decrypt_prekey<R: Rng + CryptoRng>(
165     ciphertext: &PreKeySignalMessage,
166     remote_address: &ProtocolAddress,
167     session_store: &mut dyn SessionStore,
168     identity_store: &mut dyn IdentityKeyStore,
169     pre_key_store: &mut dyn PreKeyStore,
170     signed_pre_key_store: &mut dyn SignedPreKeyStore,
171     csprng: &mut R,
172     ctx: Context,
173 ) -> Result<Vec<u8>> {
174     let mut session_record = session_store
175         .load_session(remote_address, ctx)
176         .await?
177         .unwrap_or_else(SessionRecord::new_fresh);
178 
179     // Make sure we log the session state if we fail to process the pre-key.
180     let pre_key_id_or_err = session::process_prekey(
181         ciphertext,
182         remote_address,
183         &mut session_record,
184         identity_store,
185         pre_key_store,
186         signed_pre_key_store,
187         ctx,
188     )
189     .await;
190 
191     let pre_key_id = match pre_key_id_or_err {
192         Ok(id) => id,
193         Err(e) => {
194             let errs = [e];
195             log::error!(
196                 "{}",
197                 create_decryption_failure_log(
198                     remote_address,
199                     &errs,
200                     &session_record,
201                     ciphertext.message()
202                 )?
203             );
204             let [e] = errs;
205             return Err(e);
206         }
207     };
208 
209     let ptext = decrypt_message_with_record(
210         remote_address,
211         &mut session_record,
212         ciphertext.message(),
213         csprng,
214     )?;
215 
216     session_store
217         .store_session(remote_address, &session_record, ctx)
218         .await?;
219 
220     if let Some(pre_key_id) = pre_key_id {
221         pre_key_store.remove_pre_key(pre_key_id, ctx).await?;
222     }
223 
224     Ok(ptext)
225 }
226 
message_decrypt_signal<R: Rng + CryptoRng>( ciphertext: &SignalMessage, remote_address: &ProtocolAddress, session_store: &mut dyn SessionStore, identity_store: &mut dyn IdentityKeyStore, csprng: &mut R, ctx: Context, ) -> Result<Vec<u8>>227 pub async fn message_decrypt_signal<R: Rng + CryptoRng>(
228     ciphertext: &SignalMessage,
229     remote_address: &ProtocolAddress,
230     session_store: &mut dyn SessionStore,
231     identity_store: &mut dyn IdentityKeyStore,
232     csprng: &mut R,
233     ctx: Context,
234 ) -> Result<Vec<u8>> {
235     let mut session_record = session_store
236         .load_session(remote_address, ctx)
237         .await?
238         .ok_or_else(|| SignalProtocolError::SessionNotFound(format!("{}", remote_address)))?;
239 
240     let ptext =
241         decrypt_message_with_record(remote_address, &mut session_record, ciphertext, csprng)?;
242 
243     // Why are we performing this check after decryption instead of before?
244     let their_identity_key = session_record
245         .session_state()?
246         .remote_identity_key()?
247         .ok_or(SignalProtocolError::InvalidSessionStructure)?;
248 
249     if !identity_store
250         .is_trusted_identity(
251             remote_address,
252             &their_identity_key,
253             Direction::Receiving,
254             ctx,
255         )
256         .await?
257     {
258         log::warn!(
259             "Identity key {} is not trusted for remote address {}",
260             their_identity_key
261                 .public_key()
262                 .public_key_bytes()
263                 .map_or_else(|e| format!("<error: {}>", e), hex::encode),
264             remote_address,
265         );
266         return Err(SignalProtocolError::UntrustedIdentity(
267             remote_address.clone(),
268         ));
269     }
270 
271     identity_store
272         .save_identity(remote_address, &their_identity_key, ctx)
273         .await?;
274 
275     session_store
276         .store_session(remote_address, &session_record, ctx)
277         .await?;
278 
279     Ok(ptext)
280 }
281 
create_decryption_failure_log( remote_address: &ProtocolAddress, mut errs: &[SignalProtocolError], record: &SessionRecord, ciphertext: &SignalMessage, ) -> Result<String>282 fn create_decryption_failure_log(
283     remote_address: &ProtocolAddress,
284     mut errs: &[SignalProtocolError],
285     record: &SessionRecord,
286     ciphertext: &SignalMessage,
287 ) -> Result<String> {
288     fn append_session_summary(
289         lines: &mut Vec<String>,
290         idx: usize,
291         state: Result<&SessionState>,
292         err: Option<&SignalProtocolError>,
293     ) {
294         let chains = state.and_then(|state| state.all_receiver_chain_logging_info());
295         match (err, &chains) {
296             (Some(err), Ok(chains)) => {
297                 lines.push(format!(
298                     "Candidate session {} failed with '{}', had {} receiver chains",
299                     idx,
300                     err,
301                     chains.len()
302                 ));
303             }
304             (Some(err), Err(state_err)) => {
305                 lines.push(format!(
306                     "Candidate session {} failed with '{}'; cannot get receiver chain info ({})",
307                     idx, err, state_err,
308                 ));
309             }
310             (None, Ok(chains)) => {
311                 lines.push(format!(
312                     "Candidate session {} had {} receiver chains",
313                     idx,
314                     chains.len()
315                 ));
316             }
317             (None, Err(state_err)) => {
318                 lines.push(format!(
319                     "Candidate session {}: cannot get receiver chain info ({})",
320                     idx, state_err,
321                 ));
322             }
323         }
324 
325         if let Ok(chains) = chains {
326             for chain in chains {
327                 let chain_idx = match chain.1 {
328                     Some(i) => i.to_string(),
329                     None => "missing in protobuf".to_string(),
330                 };
331 
332                 lines.push(format!(
333                     "Receiver chain with sender ratchet public key {} chain key index {}",
334                     hex::encode(chain.0),
335                     chain_idx
336                 ));
337             }
338         }
339     }
340 
341     let mut lines = vec![];
342 
343     lines.push(format!(
344         "Message from {} failed to decrypt; sender ratchet public key {} message counter {}",
345         remote_address,
346         hex::encode(ciphertext.sender_ratchet_key().public_key_bytes()?),
347         ciphertext.counter()
348     ));
349 
350     if let Ok(current_session) = record.session_state() {
351         let err = errs.first();
352         if err.is_some() {
353             errs = &errs[1..];
354         }
355         append_session_summary(&mut lines, 0, Ok(current_session), err);
356     } else {
357         lines.push("No current session".to_string());
358     }
359 
360     for (idx, (state, err)) in record
361         .previous_session_states()
362         .zip(errs.iter().map(Some).chain(std::iter::repeat(None)))
363         .enumerate()
364     {
365         let state = match state {
366             Ok(ref state) => Ok(state),
367             Err(err) => Err(err),
368         };
369         append_session_summary(&mut lines, idx + 1, state, err);
370     }
371 
372     Ok(lines.join("\n"))
373 }
374 
decrypt_message_with_record<R: Rng + CryptoRng>( remote_address: &ProtocolAddress, record: &mut SessionRecord, ciphertext: &SignalMessage, csprng: &mut R, ) -> Result<Vec<u8>>375 fn decrypt_message_with_record<R: Rng + CryptoRng>(
376     remote_address: &ProtocolAddress,
377     record: &mut SessionRecord,
378     ciphertext: &SignalMessage,
379     csprng: &mut R,
380 ) -> Result<Vec<u8>> {
381     let log_decryption_failure = |state: &SessionState, error: &SignalProtocolError| {
382         // A warning rather than an error because we try multiple sessions.
383         log::warn!(
384             "Failed to decrypt whisper message with ratchet key: {} and counter: {}. \
385              Session loaded for {}. Local session has base key: {} and counter: {}. {}",
386             ciphertext
387                 .sender_ratchet_key()
388                 .public_key_bytes()
389                 .map_or_else(|e| format!("<error: {}>", e), hex::encode),
390             ciphertext.counter(),
391             remote_address,
392             state
393                 .sender_ratchet_key_for_logging()
394                 .unwrap_or_else(|e| format!("<error: {}>", e)),
395             state.previous_counter().unwrap_or(u32::MAX),
396             error
397         );
398     };
399 
400     let mut errs = vec![];
401 
402     if let Ok(current_state) = record.session_state() {
403         let mut current_state = current_state.clone();
404         let result =
405             decrypt_message_with_state(&mut current_state, ciphertext, remote_address, csprng);
406 
407         match result {
408             Ok(ptext) => {
409                 log::info!(
410                     "decrypted message from {} with current session state (base key {})",
411                     remote_address,
412                     current_state
413                         .sender_ratchet_key_for_logging()
414                         .expect("successful decrypt always has a valid base key"),
415                 );
416                 record.set_session_state(current_state)?; // update the state
417                 return Ok(ptext);
418             }
419             Err(SignalProtocolError::DuplicatedMessage(_, _)) => {
420                 return result;
421             }
422             Err(e) => {
423                 log_decryption_failure(&current_state, &e);
424                 errs.push(e);
425             }
426         }
427     }
428 
429     // Try some old sessions:
430     let mut updated_session = None;
431 
432     for (idx, previous) in record.previous_session_states().enumerate() {
433         let mut previous = previous?;
434 
435         let result = decrypt_message_with_state(&mut previous, ciphertext, remote_address, csprng);
436 
437         match result {
438             Ok(ptext) => {
439                 log::info!(
440                     "decrypted message from {} with PREVIOUS session state (base key {})",
441                     remote_address,
442                     previous
443                         .sender_ratchet_key_for_logging()
444                         .expect("successful decrypt always has a valid base key"),
445                 );
446                 updated_session = Some((ptext, idx, previous));
447                 break;
448             }
449             Err(SignalProtocolError::DuplicatedMessage(_, _)) => {
450                 return result;
451             }
452             Err(e) => {
453                 log_decryption_failure(&previous, &e);
454                 errs.push(e);
455             }
456         }
457     }
458 
459     if let Some((ptext, idx, updated_session)) = updated_session {
460         record.promote_old_session(idx, updated_session)?;
461         Ok(ptext)
462     } else {
463         let previous_state_count = || record.previous_session_states().len();
464 
465         if let Ok(current_state) = record.session_state() {
466             log::error!(
467                 "No valid session for recipient: {}, current session base key {}, number of previous states: {}",
468                 remote_address,
469                 current_state.sender_ratchet_key_for_logging()
470                 .unwrap_or_else(|e| format!("<error: {}>", e)),
471                 previous_state_count(),
472             );
473         } else {
474             log::error!(
475                 "No valid session for recipient: {}, (no current session state), number of previous states: {}",
476                 remote_address,
477                 previous_state_count(),
478             );
479         }
480         log::error!(
481             "{}",
482             create_decryption_failure_log(remote_address, &errs, record, ciphertext)?
483         );
484         Err(SignalProtocolError::InvalidMessage(
485             "Message decryption failed",
486         ))
487     }
488 }
489 
decrypt_message_with_state<R: Rng + CryptoRng>( state: &mut SessionState, ciphertext: &SignalMessage, remote_address: &ProtocolAddress, csprng: &mut R, ) -> Result<Vec<u8>>490 fn decrypt_message_with_state<R: Rng + CryptoRng>(
491     state: &mut SessionState,
492     ciphertext: &SignalMessage,
493     remote_address: &ProtocolAddress,
494     csprng: &mut R,
495 ) -> Result<Vec<u8>> {
496     if !state.has_sender_chain()? {
497         return Err(SignalProtocolError::InvalidMessage(
498             "No session available to decrypt",
499         ));
500     }
501 
502     let ciphertext_version = ciphertext.message_version() as u32;
503     if ciphertext_version != state.session_version()? {
504         return Err(SignalProtocolError::UnrecognizedMessageVersion(
505             ciphertext_version,
506         ));
507     }
508 
509     let their_ephemeral = ciphertext.sender_ratchet_key();
510     let counter = ciphertext.counter();
511     let chain_key = get_or_create_chain_key(state, their_ephemeral, remote_address, csprng)?;
512     let message_keys =
513         get_or_create_message_key(state, their_ephemeral, remote_address, &chain_key, counter)?;
514 
515     let their_identity_key = state
516         .remote_identity_key()?
517         .ok_or(SignalProtocolError::InvalidSessionStructure)?;
518 
519     let mac_valid = ciphertext.verify_mac(
520         &their_identity_key,
521         &state.local_identity_key()?,
522         message_keys.mac_key(),
523     )?;
524 
525     if !mac_valid {
526         return Err(SignalProtocolError::InvalidCiphertext);
527     }
528 
529     let ptext = crypto::aes_256_cbc_decrypt(
530         ciphertext.body(),
531         message_keys.cipher_key(),
532         message_keys.iv(),
533     )?;
534 
535     state.clear_unacknowledged_pre_key_message()?;
536 
537     Ok(ptext)
538 }
539 
get_or_create_chain_key<R: Rng + CryptoRng>( state: &mut SessionState, their_ephemeral: &PublicKey, remote_address: &ProtocolAddress, csprng: &mut R, ) -> Result<ChainKey>540 fn get_or_create_chain_key<R: Rng + CryptoRng>(
541     state: &mut SessionState,
542     their_ephemeral: &PublicKey,
543     remote_address: &ProtocolAddress,
544     csprng: &mut R,
545 ) -> Result<ChainKey> {
546     if let Some(chain) = state.get_receiver_chain_key(their_ephemeral)? {
547         log::debug!("{} has existing receiver chain.", remote_address);
548         return Ok(chain);
549     }
550 
551     log::info!("{} creating new chains.", remote_address);
552 
553     let root_key = state.root_key()?;
554     let our_ephemeral = state.sender_ratchet_private_key()?;
555     let receiver_chain = root_key.create_chain(their_ephemeral, &our_ephemeral)?;
556     let our_new_ephemeral = KeyPair::generate(csprng);
557     let sender_chain = receiver_chain
558         .0
559         .create_chain(their_ephemeral, &our_new_ephemeral.private_key)?;
560 
561     state.set_root_key(&sender_chain.0)?;
562     state.add_receiver_chain(their_ephemeral, &receiver_chain.1)?;
563 
564     let current_index = state.get_sender_chain_key()?.index();
565     let previous_index = if current_index > 0 {
566         current_index - 1
567     } else {
568         0
569     };
570     state.set_previous_counter(previous_index)?;
571     state.set_sender_chain(&our_new_ephemeral, &sender_chain.1)?;
572 
573     Ok(receiver_chain.1)
574 }
575 
get_or_create_message_key( state: &mut SessionState, their_ephemeral: &PublicKey, remote_address: &ProtocolAddress, chain_key: &ChainKey, counter: u32, ) -> Result<MessageKeys>576 fn get_or_create_message_key(
577     state: &mut SessionState,
578     their_ephemeral: &PublicKey,
579     remote_address: &ProtocolAddress,
580     chain_key: &ChainKey,
581     counter: u32,
582 ) -> Result<MessageKeys> {
583     let chain_index = chain_key.index();
584 
585     if chain_index > counter {
586         return match state.get_message_keys(their_ephemeral, counter)? {
587             Some(keys) => Ok(keys),
588             None => {
589                 log::info!(
590                     "{} Duplicate message for counter: {}",
591                     remote_address,
592                     counter
593                 );
594                 Err(SignalProtocolError::DuplicatedMessage(chain_index, counter))
595             }
596         };
597     }
598 
599     assert!(chain_index <= counter);
600 
601     let jump = (counter - chain_index) as usize;
602 
603     if jump > MAX_FORWARD_JUMPS {
604         if state.session_with_self()? {
605             log::info!(
606                 "{} Jumping ahead {} messages (index: {}, counter: {})",
607                 remote_address,
608                 jump,
609                 chain_index,
610                 counter
611             );
612         } else {
613             log::error!(
614                 "{} Exceeded future message limit: {}, index: {}, counter: {})",
615                 remote_address,
616                 MAX_FORWARD_JUMPS,
617                 chain_index,
618                 counter
619             );
620             return Err(SignalProtocolError::InvalidMessage(
621                 "message from too far into the future",
622             ));
623         }
624     }
625 
626     let mut chain_key = chain_key.clone();
627 
628     while chain_key.index() < counter {
629         let message_keys = chain_key.message_keys()?;
630         state.set_message_keys(their_ephemeral, &message_keys)?;
631         chain_key = chain_key.next_chain_key()?;
632     }
633 
634     state.set_receiver_chain_key(their_ephemeral, &chain_key.next_chain_key()?)?;
635     chain_key.message_keys()
636 }
637