1 //
2 // Copyright 2021 Signal Messenger, LLC.
3 // SPDX-License-Identifier: AGPL-3.0-only
4 //
5 
6 use crate::{Aes256Ctr32, Error, Result};
7 use aes::{Aes256, BlockEncrypt, NewBlockCipher};
8 use generic_array::GenericArray;
9 use ghash::universal_hash::{NewUniversalHash, UniversalHash};
10 use ghash::GHash;
11 use subtle::ConstantTimeEq;
12 
/// Length in bytes of a GCM authentication tag (128 bits).
pub const TAG_SIZE: usize = 128 / 8;
/// Length in bytes of the only nonce size accepted by this module (96 bits).
pub const NONCE_SIZE: usize = 96 / 8;
15 
16 #[derive(Clone)]
17 struct GcmGhash {
18     ghash: GHash,
19     ghash_pad: [u8; TAG_SIZE],
20     msg_buf: [u8; TAG_SIZE],
21     msg_buf_offset: usize,
22     ad_len: usize,
23     msg_len: usize,
24 }
25 
26 impl GcmGhash {
new(h: &[u8; TAG_SIZE], ghash_pad: [u8; TAG_SIZE], associated_data: &[u8]) -> Result<Self>27     fn new(h: &[u8; TAG_SIZE], ghash_pad: [u8; TAG_SIZE], associated_data: &[u8]) -> Result<Self> {
28         let mut ghash = GHash::new(h.into());
29 
30         ghash.update_padded(associated_data);
31 
32         Ok(Self {
33             ghash,
34             ghash_pad,
35             msg_buf: [0u8; TAG_SIZE],
36             msg_buf_offset: 0,
37             ad_len: associated_data.len(),
38             msg_len: 0,
39         })
40     }
41 
update(&mut self, msg: &[u8]) -> Result<()>42     fn update(&mut self, msg: &[u8]) -> Result<()> {
43         if self.msg_buf_offset > 0 {
44             let taking = std::cmp::min(msg.len(), TAG_SIZE - self.msg_buf_offset);
45             self.msg_buf[self.msg_buf_offset..self.msg_buf_offset + taking]
46                 .copy_from_slice(&msg[..taking]);
47             self.msg_buf_offset += taking;
48             assert!(self.msg_buf_offset <= TAG_SIZE);
49 
50             self.msg_len += taking;
51 
52             if self.msg_buf_offset == TAG_SIZE {
53                 self.ghash.update(&self.msg_buf.into());
54                 self.msg_buf_offset = 0;
55                 return self.update(&msg[taking..]);
56             } else {
57                 return Ok(());
58             }
59         }
60 
61         self.msg_len += msg.len();
62 
63         assert_eq!(self.msg_buf_offset, 0);
64         let full_blocks = msg.len() / 16;
65         let leftover = msg.len() - 16 * full_blocks;
66         assert!(leftover < TAG_SIZE);
67         if full_blocks > 0 {
68             for block in msg[..full_blocks * 16].chunks_exact(16) {
69                 self.ghash.update(block.into());
70             }
71         }
72 
73         self.msg_buf[0..leftover].copy_from_slice(&msg[full_blocks * 16..]);
74         self.msg_buf_offset = leftover;
75         assert!(self.msg_buf_offset < TAG_SIZE);
76 
77         Ok(())
78     }
79 
finalize(mut self) -> Result<[u8; TAG_SIZE]>80     fn finalize(mut self) -> Result<[u8; TAG_SIZE]> {
81         if self.msg_buf_offset > 0 {
82             self.ghash
83                 .update_padded(&self.msg_buf[..self.msg_buf_offset]);
84         }
85 
86         let mut final_block = [0u8; 16];
87         final_block[..8].copy_from_slice(&(8 * self.ad_len as u64).to_be_bytes());
88         final_block[8..].copy_from_slice(&(8 * self.msg_len as u64).to_be_bytes());
89 
90         self.ghash.update(&final_block.into());
91         let mut hash = self.ghash.finalize().into_bytes();
92 
93         for (i, b) in hash.iter_mut().enumerate() {
94             *b ^= self.ghash_pad[i];
95         }
96 
97         Ok(hash.into())
98     }
99 }
100 
setup_gcm(key: &[u8], nonce: &[u8], associated_data: &[u8]) -> Result<(Aes256Ctr32, GcmGhash)>101 fn setup_gcm(key: &[u8], nonce: &[u8], associated_data: &[u8]) -> Result<(Aes256Ctr32, GcmGhash)> {
102     /*
103     GCM supports other sizes but 12 bytes is standard and other
104     sizes require special handling
105      */
106     if nonce.len() != NONCE_SIZE {
107         return Err(Error::InvalidNonceSize);
108     }
109 
110     let aes256 = Aes256::new_from_slice(key).map_err(|_| Error::InvalidKeySize)?;
111     let mut h = [0u8; TAG_SIZE];
112     aes256.encrypt_block(GenericArray::from_mut_slice(&mut h));
113 
114     let mut ctr = Aes256Ctr32::new(aes256, nonce, 1)?;
115 
116     let mut ghash_pad = [0u8; 16];
117     ctr.process(&mut ghash_pad)?;
118 
119     let ghash = GcmGhash::new(&h, ghash_pad, associated_data)?;
120     Ok((ctr, ghash))
121 }
122 
123 pub struct Aes256GcmEncryption {
124     ctr: Aes256Ctr32,
125     ghash: GcmGhash,
126 }
127 
128 impl Aes256GcmEncryption {
129     pub const TAG_SIZE: usize = TAG_SIZE;
130     pub const NONCE_SIZE: usize = NONCE_SIZE;
131 
new(key: &[u8], nonce: &[u8], associated_data: &[u8]) -> Result<Self>132     pub fn new(key: &[u8], nonce: &[u8], associated_data: &[u8]) -> Result<Self> {
133         let (ctr, ghash) = setup_gcm(key, nonce, associated_data)?;
134         Ok(Self { ctr, ghash })
135     }
136 
encrypt(&mut self, buf: &mut [u8]) -> Result<()>137     pub fn encrypt(&mut self, buf: &mut [u8]) -> Result<()> {
138         self.ctr.process(buf)?;
139         self.ghash.update(buf)?;
140         Ok(())
141     }
142 
compute_tag(self) -> Result<[u8; TAG_SIZE]>143     pub fn compute_tag(self) -> Result<[u8; TAG_SIZE]> {
144         self.ghash.finalize()
145     }
146 }
147 
148 pub struct Aes256GcmDecryption {
149     ctr: Aes256Ctr32,
150     ghash: GcmGhash,
151 }
152 
153 impl Aes256GcmDecryption {
154     pub const TAG_SIZE: usize = TAG_SIZE;
155     pub const NONCE_SIZE: usize = NONCE_SIZE;
156 
new(key: &[u8], nonce: &[u8], associated_data: &[u8]) -> Result<Self>157     pub fn new(key: &[u8], nonce: &[u8], associated_data: &[u8]) -> Result<Self> {
158         let (ctr, ghash) = setup_gcm(key, nonce, associated_data)?;
159         Ok(Self { ctr, ghash })
160     }
161 
decrypt(&mut self, buf: &mut [u8]) -> Result<()>162     pub fn decrypt(&mut self, buf: &mut [u8]) -> Result<()> {
163         self.ghash.update(buf)?;
164         self.ctr.process(buf)?;
165         Ok(())
166     }
167 
verify_tag(self, tag: &[u8]) -> Result<()>168     pub fn verify_tag(self, tag: &[u8]) -> Result<()> {
169         if tag.len() != TAG_SIZE {
170             return Err(Error::InvalidTag);
171         }
172 
173         let computed_tag = self.ghash.finalize()?;
174 
175         let tag_ok = tag.ct_eq(&computed_tag);
176 
177         if !bool::from(tag_ok) {
178             return Err(Error::InvalidTag);
179         }
180 
181         Ok(())
182     }
183 }
184