1 //
2 // Copyright 2021 Signal Messenger, LLC.
3 // SPDX-License-Identifier: AGPL-3.0-only
4 //
5
6 use crate::{Aes256Ctr32, Error, Result};
7 use aes::{Aes256, BlockEncrypt, NewBlockCipher};
8 use generic_array::GenericArray;
9 use ghash::universal_hash::{NewUniversalHash, UniversalHash};
10 use ghash::GHash;
11 use subtle::ConstantTimeEq;
12
/// Size in bytes of the authentication tag handled by this module.
pub const TAG_SIZE: usize = 16;

/// Size in bytes of the nonce handled by this module.
pub const NONCE_SIZE: usize = 12;
15
16 #[derive(Clone)]
17 struct GcmGhash {
18 ghash: GHash,
19 ghash_pad: [u8; TAG_SIZE],
20 msg_buf: [u8; TAG_SIZE],
21 msg_buf_offset: usize,
22 ad_len: usize,
23 msg_len: usize,
24 }
25
26 impl GcmGhash {
new(h: &[u8; TAG_SIZE], ghash_pad: [u8; TAG_SIZE], associated_data: &[u8]) -> Result<Self>27 fn new(h: &[u8; TAG_SIZE], ghash_pad: [u8; TAG_SIZE], associated_data: &[u8]) -> Result<Self> {
28 let mut ghash = GHash::new(h.into());
29
30 ghash.update_padded(associated_data);
31
32 Ok(Self {
33 ghash,
34 ghash_pad,
35 msg_buf: [0u8; TAG_SIZE],
36 msg_buf_offset: 0,
37 ad_len: associated_data.len(),
38 msg_len: 0,
39 })
40 }
41
update(&mut self, msg: &[u8]) -> Result<()>42 fn update(&mut self, msg: &[u8]) -> Result<()> {
43 if self.msg_buf_offset > 0 {
44 let taking = std::cmp::min(msg.len(), TAG_SIZE - self.msg_buf_offset);
45 self.msg_buf[self.msg_buf_offset..self.msg_buf_offset + taking]
46 .copy_from_slice(&msg[..taking]);
47 self.msg_buf_offset += taking;
48 assert!(self.msg_buf_offset <= TAG_SIZE);
49
50 self.msg_len += taking;
51
52 if self.msg_buf_offset == TAG_SIZE {
53 self.ghash.update(&self.msg_buf.into());
54 self.msg_buf_offset = 0;
55 return self.update(&msg[taking..]);
56 } else {
57 return Ok(());
58 }
59 }
60
61 self.msg_len += msg.len();
62
63 assert_eq!(self.msg_buf_offset, 0);
64 let full_blocks = msg.len() / 16;
65 let leftover = msg.len() - 16 * full_blocks;
66 assert!(leftover < TAG_SIZE);
67 if full_blocks > 0 {
68 for block in msg[..full_blocks * 16].chunks_exact(16) {
69 self.ghash.update(block.into());
70 }
71 }
72
73 self.msg_buf[0..leftover].copy_from_slice(&msg[full_blocks * 16..]);
74 self.msg_buf_offset = leftover;
75 assert!(self.msg_buf_offset < TAG_SIZE);
76
77 Ok(())
78 }
79
finalize(mut self) -> Result<[u8; TAG_SIZE]>80 fn finalize(mut self) -> Result<[u8; TAG_SIZE]> {
81 if self.msg_buf_offset > 0 {
82 self.ghash
83 .update_padded(&self.msg_buf[..self.msg_buf_offset]);
84 }
85
86 let mut final_block = [0u8; 16];
87 final_block[..8].copy_from_slice(&(8 * self.ad_len as u64).to_be_bytes());
88 final_block[8..].copy_from_slice(&(8 * self.msg_len as u64).to_be_bytes());
89
90 self.ghash.update(&final_block.into());
91 let mut hash = self.ghash.finalize().into_bytes();
92
93 for (i, b) in hash.iter_mut().enumerate() {
94 *b ^= self.ghash_pad[i];
95 }
96
97 Ok(hash.into())
98 }
99 }
100
setup_gcm(key: &[u8], nonce: &[u8], associated_data: &[u8]) -> Result<(Aes256Ctr32, GcmGhash)>101 fn setup_gcm(key: &[u8], nonce: &[u8], associated_data: &[u8]) -> Result<(Aes256Ctr32, GcmGhash)> {
102 /*
103 GCM supports other sizes but 12 bytes is standard and other
104 sizes require special handling
105 */
106 if nonce.len() != NONCE_SIZE {
107 return Err(Error::InvalidNonceSize);
108 }
109
110 let aes256 = Aes256::new_from_slice(key).map_err(|_| Error::InvalidKeySize)?;
111 let mut h = [0u8; TAG_SIZE];
112 aes256.encrypt_block(GenericArray::from_mut_slice(&mut h));
113
114 let mut ctr = Aes256Ctr32::new(aes256, nonce, 1)?;
115
116 let mut ghash_pad = [0u8; 16];
117 ctr.process(&mut ghash_pad)?;
118
119 let ghash = GcmGhash::new(&h, ghash_pad, associated_data)?;
120 Ok((ctr, ghash))
121 }
122
123 pub struct Aes256GcmEncryption {
124 ctr: Aes256Ctr32,
125 ghash: GcmGhash,
126 }
127
128 impl Aes256GcmEncryption {
129 pub const TAG_SIZE: usize = TAG_SIZE;
130 pub const NONCE_SIZE: usize = NONCE_SIZE;
131
new(key: &[u8], nonce: &[u8], associated_data: &[u8]) -> Result<Self>132 pub fn new(key: &[u8], nonce: &[u8], associated_data: &[u8]) -> Result<Self> {
133 let (ctr, ghash) = setup_gcm(key, nonce, associated_data)?;
134 Ok(Self { ctr, ghash })
135 }
136
encrypt(&mut self, buf: &mut [u8]) -> Result<()>137 pub fn encrypt(&mut self, buf: &mut [u8]) -> Result<()> {
138 self.ctr.process(buf)?;
139 self.ghash.update(buf)?;
140 Ok(())
141 }
142
compute_tag(self) -> Result<[u8; TAG_SIZE]>143 pub fn compute_tag(self) -> Result<[u8; TAG_SIZE]> {
144 self.ghash.finalize()
145 }
146 }
147
148 pub struct Aes256GcmDecryption {
149 ctr: Aes256Ctr32,
150 ghash: GcmGhash,
151 }
152
153 impl Aes256GcmDecryption {
154 pub const TAG_SIZE: usize = TAG_SIZE;
155 pub const NONCE_SIZE: usize = NONCE_SIZE;
156
new(key: &[u8], nonce: &[u8], associated_data: &[u8]) -> Result<Self>157 pub fn new(key: &[u8], nonce: &[u8], associated_data: &[u8]) -> Result<Self> {
158 let (ctr, ghash) = setup_gcm(key, nonce, associated_data)?;
159 Ok(Self { ctr, ghash })
160 }
161
decrypt(&mut self, buf: &mut [u8]) -> Result<()>162 pub fn decrypt(&mut self, buf: &mut [u8]) -> Result<()> {
163 self.ghash.update(buf)?;
164 self.ctr.process(buf)?;
165 Ok(())
166 }
167
verify_tag(self, tag: &[u8]) -> Result<()>168 pub fn verify_tag(self, tag: &[u8]) -> Result<()> {
169 if tag.len() != TAG_SIZE {
170 return Err(Error::InvalidTag);
171 }
172
173 let computed_tag = self.ghash.finalize()?;
174
175 let tag_ok = tag.ct_eq(&computed_tag);
176
177 if !bool::from(tag_ok) {
178 return Err(Error::InvalidTag);
179 }
180
181 Ok(())
182 }
183 }
184