1 //
2 // Copyright 2020-2021 Signal Messenger, LLC.
3 // SPDX-License-Identifier: AGPL-3.0-only
4 //
5
6 use crate::{
7 CiphertextMessage, Context, Direction, IdentityKeyStore, KeyPair, PreKeySignalMessage,
8 PreKeyStore, ProtocolAddress, PublicKey, Result, SessionRecord, SessionStore, SignalMessage,
9 SignalProtocolError, SignedPreKeyStore,
10 };
11
12 use crate::consts::MAX_FORWARD_JUMPS;
13 use crate::crypto;
14 use crate::ratchet::{ChainKey, MessageKeys};
15 use crate::session;
16 use crate::state::SessionState;
17
18 use rand::{CryptoRng, Rng};
19
message_encrypt( ptext: &[u8], remote_address: &ProtocolAddress, session_store: &mut dyn SessionStore, identity_store: &mut dyn IdentityKeyStore, ctx: Context, ) -> Result<CiphertextMessage>20 pub async fn message_encrypt(
21 ptext: &[u8],
22 remote_address: &ProtocolAddress,
23 session_store: &mut dyn SessionStore,
24 identity_store: &mut dyn IdentityKeyStore,
25 ctx: Context,
26 ) -> Result<CiphertextMessage> {
27 let mut session_record = session_store
28 .load_session(remote_address, ctx)
29 .await?
30 .ok_or_else(|| SignalProtocolError::SessionNotFound(format!("{}", remote_address)))?;
31 let session_state = session_record.session_state_mut()?;
32
33 let chain_key = session_state.get_sender_chain_key()?;
34
35 let message_keys = chain_key.message_keys()?;
36
37 let sender_ephemeral = session_state.sender_ratchet_key()?;
38 let previous_counter = session_state.previous_counter()?;
39 let session_version = session_state.session_version()? as u8;
40
41 let local_identity_key = session_state.local_identity_key()?;
42 let their_identity_key = session_state
43 .remote_identity_key()?
44 .ok_or(SignalProtocolError::InvalidSessionStructure)?;
45
46 let ctext = crypto::aes_256_cbc_encrypt(ptext, message_keys.cipher_key(), message_keys.iv())?;
47
48 let message = if let Some(items) = session_state.unacknowledged_pre_key_message_items()? {
49 let local_registration_id = session_state.local_registration_id()?;
50
51 log::info!(
52 "Building PreKeyWhisperMessage for: {} with preKeyId: {}",
53 remote_address,
54 items
55 .pre_key_id()?
56 .map_or_else(|| "<none>".to_string(), |id| id.to_string())
57 );
58
59 let message = SignalMessage::new(
60 session_version,
61 message_keys.mac_key(),
62 sender_ephemeral,
63 chain_key.index(),
64 previous_counter,
65 &ctext,
66 &local_identity_key,
67 &their_identity_key,
68 )?;
69
70 CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::new(
71 session_version,
72 local_registration_id,
73 items.pre_key_id()?,
74 items.signed_pre_key_id()?,
75 *items.base_key()?,
76 local_identity_key,
77 message,
78 )?)
79 } else {
80 CiphertextMessage::SignalMessage(SignalMessage::new(
81 session_version,
82 message_keys.mac_key(),
83 sender_ephemeral,
84 chain_key.index(),
85 previous_counter,
86 &ctext,
87 &local_identity_key,
88 &their_identity_key,
89 )?)
90 };
91
92 session_state.set_sender_chain_key(&chain_key.next_chain_key()?)?;
93
94 // XXX why is this check after everything else?!!
95 if !identity_store
96 .is_trusted_identity(remote_address, &their_identity_key, Direction::Sending, ctx)
97 .await?
98 {
99 log::warn!(
100 "Identity key {} is not trusted for remote address {}",
101 their_identity_key
102 .public_key()
103 .public_key_bytes()
104 .map_or_else(|e| format!("<error: {}>", e), hex::encode),
105 remote_address,
106 );
107 return Err(SignalProtocolError::UntrustedIdentity(
108 remote_address.clone(),
109 ));
110 }
111
112 // XXX this could be combined with the above call to the identity store (in a new API)
113 identity_store
114 .save_identity(remote_address, &their_identity_key, ctx)
115 .await?;
116
117 session_store
118 .store_session(remote_address, &session_record, ctx)
119 .await?;
120 Ok(message)
121 }
122
message_decrypt<R: Rng + CryptoRng>( ciphertext: &CiphertextMessage, remote_address: &ProtocolAddress, session_store: &mut dyn SessionStore, identity_store: &mut dyn IdentityKeyStore, pre_key_store: &mut dyn PreKeyStore, signed_pre_key_store: &mut dyn SignedPreKeyStore, csprng: &mut R, ctx: Context, ) -> Result<Vec<u8>>123 pub async fn message_decrypt<R: Rng + CryptoRng>(
124 ciphertext: &CiphertextMessage,
125 remote_address: &ProtocolAddress,
126 session_store: &mut dyn SessionStore,
127 identity_store: &mut dyn IdentityKeyStore,
128 pre_key_store: &mut dyn PreKeyStore,
129 signed_pre_key_store: &mut dyn SignedPreKeyStore,
130 csprng: &mut R,
131 ctx: Context,
132 ) -> Result<Vec<u8>> {
133 match ciphertext {
134 CiphertextMessage::SignalMessage(m) => {
135 message_decrypt_signal(
136 m,
137 remote_address,
138 session_store,
139 identity_store,
140 csprng,
141 ctx,
142 )
143 .await
144 }
145 CiphertextMessage::PreKeySignalMessage(m) => {
146 message_decrypt_prekey(
147 m,
148 remote_address,
149 session_store,
150 identity_store,
151 pre_key_store,
152 signed_pre_key_store,
153 csprng,
154 ctx,
155 )
156 .await
157 }
158 _ => Err(SignalProtocolError::InvalidArgument(
159 "SessionCipher::decrypt cannot decrypt this message type".to_owned(),
160 )),
161 }
162 }
163
message_decrypt_prekey<R: Rng + CryptoRng>( ciphertext: &PreKeySignalMessage, remote_address: &ProtocolAddress, session_store: &mut dyn SessionStore, identity_store: &mut dyn IdentityKeyStore, pre_key_store: &mut dyn PreKeyStore, signed_pre_key_store: &mut dyn SignedPreKeyStore, csprng: &mut R, ctx: Context, ) -> Result<Vec<u8>>164 pub async fn message_decrypt_prekey<R: Rng + CryptoRng>(
165 ciphertext: &PreKeySignalMessage,
166 remote_address: &ProtocolAddress,
167 session_store: &mut dyn SessionStore,
168 identity_store: &mut dyn IdentityKeyStore,
169 pre_key_store: &mut dyn PreKeyStore,
170 signed_pre_key_store: &mut dyn SignedPreKeyStore,
171 csprng: &mut R,
172 ctx: Context,
173 ) -> Result<Vec<u8>> {
174 let mut session_record = session_store
175 .load_session(remote_address, ctx)
176 .await?
177 .unwrap_or_else(SessionRecord::new_fresh);
178
179 // Make sure we log the session state if we fail to process the pre-key.
180 let pre_key_id_or_err = session::process_prekey(
181 ciphertext,
182 remote_address,
183 &mut session_record,
184 identity_store,
185 pre_key_store,
186 signed_pre_key_store,
187 ctx,
188 )
189 .await;
190
191 let pre_key_id = match pre_key_id_or_err {
192 Ok(id) => id,
193 Err(e) => {
194 let errs = [e];
195 log::error!(
196 "{}",
197 create_decryption_failure_log(
198 remote_address,
199 &errs,
200 &session_record,
201 ciphertext.message()
202 )?
203 );
204 let [e] = errs;
205 return Err(e);
206 }
207 };
208
209 let ptext = decrypt_message_with_record(
210 remote_address,
211 &mut session_record,
212 ciphertext.message(),
213 csprng,
214 )?;
215
216 session_store
217 .store_session(remote_address, &session_record, ctx)
218 .await?;
219
220 if let Some(pre_key_id) = pre_key_id {
221 pre_key_store.remove_pre_key(pre_key_id, ctx).await?;
222 }
223
224 Ok(ptext)
225 }
226
message_decrypt_signal<R: Rng + CryptoRng>( ciphertext: &SignalMessage, remote_address: &ProtocolAddress, session_store: &mut dyn SessionStore, identity_store: &mut dyn IdentityKeyStore, csprng: &mut R, ctx: Context, ) -> Result<Vec<u8>>227 pub async fn message_decrypt_signal<R: Rng + CryptoRng>(
228 ciphertext: &SignalMessage,
229 remote_address: &ProtocolAddress,
230 session_store: &mut dyn SessionStore,
231 identity_store: &mut dyn IdentityKeyStore,
232 csprng: &mut R,
233 ctx: Context,
234 ) -> Result<Vec<u8>> {
235 let mut session_record = session_store
236 .load_session(remote_address, ctx)
237 .await?
238 .ok_or_else(|| SignalProtocolError::SessionNotFound(format!("{}", remote_address)))?;
239
240 let ptext =
241 decrypt_message_with_record(remote_address, &mut session_record, ciphertext, csprng)?;
242
243 // Why are we performing this check after decryption instead of before?
244 let their_identity_key = session_record
245 .session_state()?
246 .remote_identity_key()?
247 .ok_or(SignalProtocolError::InvalidSessionStructure)?;
248
249 if !identity_store
250 .is_trusted_identity(
251 remote_address,
252 &their_identity_key,
253 Direction::Receiving,
254 ctx,
255 )
256 .await?
257 {
258 log::warn!(
259 "Identity key {} is not trusted for remote address {}",
260 their_identity_key
261 .public_key()
262 .public_key_bytes()
263 .map_or_else(|e| format!("<error: {}>", e), hex::encode),
264 remote_address,
265 );
266 return Err(SignalProtocolError::UntrustedIdentity(
267 remote_address.clone(),
268 ));
269 }
270
271 identity_store
272 .save_identity(remote_address, &their_identity_key, ctx)
273 .await?;
274
275 session_store
276 .store_session(remote_address, &session_record, ctx)
277 .await?;
278
279 Ok(ptext)
280 }
281
create_decryption_failure_log( remote_address: &ProtocolAddress, mut errs: &[SignalProtocolError], record: &SessionRecord, ciphertext: &SignalMessage, ) -> Result<String>282 fn create_decryption_failure_log(
283 remote_address: &ProtocolAddress,
284 mut errs: &[SignalProtocolError],
285 record: &SessionRecord,
286 ciphertext: &SignalMessage,
287 ) -> Result<String> {
288 fn append_session_summary(
289 lines: &mut Vec<String>,
290 idx: usize,
291 state: Result<&SessionState>,
292 err: Option<&SignalProtocolError>,
293 ) {
294 let chains = state.and_then(|state| state.all_receiver_chain_logging_info());
295 match (err, &chains) {
296 (Some(err), Ok(chains)) => {
297 lines.push(format!(
298 "Candidate session {} failed with '{}', had {} receiver chains",
299 idx,
300 err,
301 chains.len()
302 ));
303 }
304 (Some(err), Err(state_err)) => {
305 lines.push(format!(
306 "Candidate session {} failed with '{}'; cannot get receiver chain info ({})",
307 idx, err, state_err,
308 ));
309 }
310 (None, Ok(chains)) => {
311 lines.push(format!(
312 "Candidate session {} had {} receiver chains",
313 idx,
314 chains.len()
315 ));
316 }
317 (None, Err(state_err)) => {
318 lines.push(format!(
319 "Candidate session {}: cannot get receiver chain info ({})",
320 idx, state_err,
321 ));
322 }
323 }
324
325 if let Ok(chains) = chains {
326 for chain in chains {
327 let chain_idx = match chain.1 {
328 Some(i) => i.to_string(),
329 None => "missing in protobuf".to_string(),
330 };
331
332 lines.push(format!(
333 "Receiver chain with sender ratchet public key {} chain key index {}",
334 hex::encode(chain.0),
335 chain_idx
336 ));
337 }
338 }
339 }
340
341 let mut lines = vec![];
342
343 lines.push(format!(
344 "Message from {} failed to decrypt; sender ratchet public key {} message counter {}",
345 remote_address,
346 hex::encode(ciphertext.sender_ratchet_key().public_key_bytes()?),
347 ciphertext.counter()
348 ));
349
350 if let Ok(current_session) = record.session_state() {
351 let err = errs.first();
352 if err.is_some() {
353 errs = &errs[1..];
354 }
355 append_session_summary(&mut lines, 0, Ok(current_session), err);
356 } else {
357 lines.push("No current session".to_string());
358 }
359
360 for (idx, (state, err)) in record
361 .previous_session_states()
362 .zip(errs.iter().map(Some).chain(std::iter::repeat(None)))
363 .enumerate()
364 {
365 let state = match state {
366 Ok(ref state) => Ok(state),
367 Err(err) => Err(err),
368 };
369 append_session_summary(&mut lines, idx + 1, state, err);
370 }
371
372 Ok(lines.join("\n"))
373 }
374
decrypt_message_with_record<R: Rng + CryptoRng>( remote_address: &ProtocolAddress, record: &mut SessionRecord, ciphertext: &SignalMessage, csprng: &mut R, ) -> Result<Vec<u8>>375 fn decrypt_message_with_record<R: Rng + CryptoRng>(
376 remote_address: &ProtocolAddress,
377 record: &mut SessionRecord,
378 ciphertext: &SignalMessage,
379 csprng: &mut R,
380 ) -> Result<Vec<u8>> {
381 let log_decryption_failure = |state: &SessionState, error: &SignalProtocolError| {
382 // A warning rather than an error because we try multiple sessions.
383 log::warn!(
384 "Failed to decrypt whisper message with ratchet key: {} and counter: {}. \
385 Session loaded for {}. Local session has base key: {} and counter: {}. {}",
386 ciphertext
387 .sender_ratchet_key()
388 .public_key_bytes()
389 .map_or_else(|e| format!("<error: {}>", e), hex::encode),
390 ciphertext.counter(),
391 remote_address,
392 state
393 .sender_ratchet_key_for_logging()
394 .unwrap_or_else(|e| format!("<error: {}>", e)),
395 state.previous_counter().unwrap_or(u32::MAX),
396 error
397 );
398 };
399
400 let mut errs = vec![];
401
402 if let Ok(current_state) = record.session_state() {
403 let mut current_state = current_state.clone();
404 let result =
405 decrypt_message_with_state(&mut current_state, ciphertext, remote_address, csprng);
406
407 match result {
408 Ok(ptext) => {
409 log::info!(
410 "decrypted message from {} with current session state (base key {})",
411 remote_address,
412 current_state
413 .sender_ratchet_key_for_logging()
414 .expect("successful decrypt always has a valid base key"),
415 );
416 record.set_session_state(current_state)?; // update the state
417 return Ok(ptext);
418 }
419 Err(SignalProtocolError::DuplicatedMessage(_, _)) => {
420 return result;
421 }
422 Err(e) => {
423 log_decryption_failure(¤t_state, &e);
424 errs.push(e);
425 }
426 }
427 }
428
429 // Try some old sessions:
430 let mut updated_session = None;
431
432 for (idx, previous) in record.previous_session_states().enumerate() {
433 let mut previous = previous?;
434
435 let result = decrypt_message_with_state(&mut previous, ciphertext, remote_address, csprng);
436
437 match result {
438 Ok(ptext) => {
439 log::info!(
440 "decrypted message from {} with PREVIOUS session state (base key {})",
441 remote_address,
442 previous
443 .sender_ratchet_key_for_logging()
444 .expect("successful decrypt always has a valid base key"),
445 );
446 updated_session = Some((ptext, idx, previous));
447 break;
448 }
449 Err(SignalProtocolError::DuplicatedMessage(_, _)) => {
450 return result;
451 }
452 Err(e) => {
453 log_decryption_failure(&previous, &e);
454 errs.push(e);
455 }
456 }
457 }
458
459 if let Some((ptext, idx, updated_session)) = updated_session {
460 record.promote_old_session(idx, updated_session)?;
461 Ok(ptext)
462 } else {
463 let previous_state_count = || record.previous_session_states().len();
464
465 if let Ok(current_state) = record.session_state() {
466 log::error!(
467 "No valid session for recipient: {}, current session base key {}, number of previous states: {}",
468 remote_address,
469 current_state.sender_ratchet_key_for_logging()
470 .unwrap_or_else(|e| format!("<error: {}>", e)),
471 previous_state_count(),
472 );
473 } else {
474 log::error!(
475 "No valid session for recipient: {}, (no current session state), number of previous states: {}",
476 remote_address,
477 previous_state_count(),
478 );
479 }
480 log::error!(
481 "{}",
482 create_decryption_failure_log(remote_address, &errs, record, ciphertext)?
483 );
484 Err(SignalProtocolError::InvalidMessage(
485 "Message decryption failed",
486 ))
487 }
488 }
489
decrypt_message_with_state<R: Rng + CryptoRng>( state: &mut SessionState, ciphertext: &SignalMessage, remote_address: &ProtocolAddress, csprng: &mut R, ) -> Result<Vec<u8>>490 fn decrypt_message_with_state<R: Rng + CryptoRng>(
491 state: &mut SessionState,
492 ciphertext: &SignalMessage,
493 remote_address: &ProtocolAddress,
494 csprng: &mut R,
495 ) -> Result<Vec<u8>> {
496 if !state.has_sender_chain()? {
497 return Err(SignalProtocolError::InvalidMessage(
498 "No session available to decrypt",
499 ));
500 }
501
502 let ciphertext_version = ciphertext.message_version() as u32;
503 if ciphertext_version != state.session_version()? {
504 return Err(SignalProtocolError::UnrecognizedMessageVersion(
505 ciphertext_version,
506 ));
507 }
508
509 let their_ephemeral = ciphertext.sender_ratchet_key();
510 let counter = ciphertext.counter();
511 let chain_key = get_or_create_chain_key(state, their_ephemeral, remote_address, csprng)?;
512 let message_keys =
513 get_or_create_message_key(state, their_ephemeral, remote_address, &chain_key, counter)?;
514
515 let their_identity_key = state
516 .remote_identity_key()?
517 .ok_or(SignalProtocolError::InvalidSessionStructure)?;
518
519 let mac_valid = ciphertext.verify_mac(
520 &their_identity_key,
521 &state.local_identity_key()?,
522 message_keys.mac_key(),
523 )?;
524
525 if !mac_valid {
526 return Err(SignalProtocolError::InvalidCiphertext);
527 }
528
529 let ptext = crypto::aes_256_cbc_decrypt(
530 ciphertext.body(),
531 message_keys.cipher_key(),
532 message_keys.iv(),
533 )?;
534
535 state.clear_unacknowledged_pre_key_message()?;
536
537 Ok(ptext)
538 }
539
get_or_create_chain_key<R: Rng + CryptoRng>( state: &mut SessionState, their_ephemeral: &PublicKey, remote_address: &ProtocolAddress, csprng: &mut R, ) -> Result<ChainKey>540 fn get_or_create_chain_key<R: Rng + CryptoRng>(
541 state: &mut SessionState,
542 their_ephemeral: &PublicKey,
543 remote_address: &ProtocolAddress,
544 csprng: &mut R,
545 ) -> Result<ChainKey> {
546 if let Some(chain) = state.get_receiver_chain_key(their_ephemeral)? {
547 log::debug!("{} has existing receiver chain.", remote_address);
548 return Ok(chain);
549 }
550
551 log::info!("{} creating new chains.", remote_address);
552
553 let root_key = state.root_key()?;
554 let our_ephemeral = state.sender_ratchet_private_key()?;
555 let receiver_chain = root_key.create_chain(their_ephemeral, &our_ephemeral)?;
556 let our_new_ephemeral = KeyPair::generate(csprng);
557 let sender_chain = receiver_chain
558 .0
559 .create_chain(their_ephemeral, &our_new_ephemeral.private_key)?;
560
561 state.set_root_key(&sender_chain.0)?;
562 state.add_receiver_chain(their_ephemeral, &receiver_chain.1)?;
563
564 let current_index = state.get_sender_chain_key()?.index();
565 let previous_index = if current_index > 0 {
566 current_index - 1
567 } else {
568 0
569 };
570 state.set_previous_counter(previous_index)?;
571 state.set_sender_chain(&our_new_ephemeral, &sender_chain.1)?;
572
573 Ok(receiver_chain.1)
574 }
575
get_or_create_message_key( state: &mut SessionState, their_ephemeral: &PublicKey, remote_address: &ProtocolAddress, chain_key: &ChainKey, counter: u32, ) -> Result<MessageKeys>576 fn get_or_create_message_key(
577 state: &mut SessionState,
578 their_ephemeral: &PublicKey,
579 remote_address: &ProtocolAddress,
580 chain_key: &ChainKey,
581 counter: u32,
582 ) -> Result<MessageKeys> {
583 let chain_index = chain_key.index();
584
585 if chain_index > counter {
586 return match state.get_message_keys(their_ephemeral, counter)? {
587 Some(keys) => Ok(keys),
588 None => {
589 log::info!(
590 "{} Duplicate message for counter: {}",
591 remote_address,
592 counter
593 );
594 Err(SignalProtocolError::DuplicatedMessage(chain_index, counter))
595 }
596 };
597 }
598
599 assert!(chain_index <= counter);
600
601 let jump = (counter - chain_index) as usize;
602
603 if jump > MAX_FORWARD_JUMPS {
604 if state.session_with_self()? {
605 log::info!(
606 "{} Jumping ahead {} messages (index: {}, counter: {})",
607 remote_address,
608 jump,
609 chain_index,
610 counter
611 );
612 } else {
613 log::error!(
614 "{} Exceeded future message limit: {}, index: {}, counter: {})",
615 remote_address,
616 MAX_FORWARD_JUMPS,
617 chain_index,
618 counter
619 );
620 return Err(SignalProtocolError::InvalidMessage(
621 "message from too far into the future",
622 ));
623 }
624 }
625
626 let mut chain_key = chain_key.clone();
627
628 while chain_key.index() < counter {
629 let message_keys = chain_key.message_keys()?;
630 state.set_message_keys(their_ephemeral, &message_keys)?;
631 chain_key = chain_key.next_chain_key()?;
632 }
633
634 state.set_receiver_chain_key(their_ephemeral, &chain_key.next_chain_key()?)?;
635 chain_key.message_keys()
636 }
637