#include "group_cipher.h"

#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include "protocol.h"
#include "sender_key.h"
#include "sender_key_record.h"
#include "sender_key_state.h"
#include "signal_protocol_internal.h"
10
/*
 * Cipher bound to one sender key name (see group_cipher_create()).
 * The store, sender key name and global context are only referenced, not
 * owned: group_cipher_free() releases nothing but this struct itself.
 */
struct group_cipher
{
    /* Store used to load and save the sender key record. */
    signal_protocol_store_context *store;
    /* Name of the sender key record this cipher operates on. */
    const signal_protocol_sender_key_name *sender_key_id;
    /* Passed to signal_lock/unlock, signal_encrypt/decrypt and signal_log. */
    signal_context *global_context;
    /* Optional hook run by group_cipher_decrypt() before the record is stored. */
    int (*decrypt_callback)(group_cipher *cipher, signal_buffer *plaintext, void *decrypt_context);
    /* Set to 1 while decrypt_callback runs; encrypt/decrypt then return SG_ERR_INVAL. */
    int inside_callback;
    /* Opaque caller pointer; only stored and returned by the user-data accessors. */
    void *user_data;
};
20
/* File-local helpers used by group_cipher_decrypt(); defined further down. */
static int group_cipher_get_sender_key(group_cipher *cipher, sender_message_key **sender_key, sender_key_state *state, uint32_t iteration);
static int group_cipher_decrypt_callback(group_cipher *cipher, signal_buffer *plaintext, void *decrypt_context);
23
/*
 * Allocate a group cipher bound to one sender key name.
 *
 * The store, sender_key_id and global_context pointers are stored as-is
 * (borrowed, not copied); they are used by later encrypt/decrypt calls, so
 * the caller must keep them alive for the lifetime of the cipher.
 * group_cipher_free() does not release them.
 *
 * @param cipher          out-param; receives the new cipher on success
 * @param store           store used to load/save the sender key record
 * @param sender_key_id   name identifying the sender key record
 * @param global_context  context used for locking, crypto and logging
 * @return 0 on success, SG_ERR_NOMEM if allocation fails (*cipher untouched)
 */
int group_cipher_create(group_cipher **cipher,
        signal_protocol_store_context *store, const signal_protocol_sender_key_name *sender_key_id,
        signal_context *global_context)
{
    group_cipher *result_cipher;

    assert(cipher);
    assert(store);
    assert(global_context);

    /* Zero-filled: no callback installed, not inside a callback, no user data. */
    result_cipher = calloc(1, sizeof(*result_cipher));
    if(!result_cipher) {
        return SG_ERR_NOMEM;
    }

    result_cipher->store = store;
    result_cipher->sender_key_id = sender_key_id;
    result_cipher->global_context = global_context;

    *cipher = result_cipher;
    return 0;
}
46
/*
 * Attach an opaque, caller-owned pointer to the cipher. Nothing in this file
 * dereferences or frees it; it is only handed back by
 * group_cipher_get_user_data().
 */
void group_cipher_set_user_data(group_cipher *cipher, void *user_data)
{
    group_cipher *target = cipher;

    assert(target);
    target->user_data = user_data;
}
52
/*
 * Return the pointer most recently stored with group_cipher_set_user_data(),
 * or NULL if none was ever set (a new cipher starts zero-initialised).
 */
void *group_cipher_get_user_data(group_cipher *cipher)
{
    void *stored;

    assert(cipher);
    stored = cipher->user_data;
    return stored;
}
58
/*
 * Install the hook that group_cipher_decrypt() runs after a message has been
 * decrypted but before the updated sender key record is stored. A negative
 * return from the hook aborts the decrypt without storing the record.
 * Passing NULL removes any previously installed hook.
 */
void group_cipher_set_decryption_callback(group_cipher *cipher,
        int (*callback)(group_cipher *cipher, signal_buffer *plaintext, void *decrypt_context))
{
    group_cipher *target = cipher;

    assert(target);
    target->decrypt_callback = callback;
}
65
/*
 * Encrypt a (caller-padded) plaintext with the current sender key for
 * cipher->sender_key_id and wrap it in a signed sender_key_message.
 *
 * Under the context lock, this function performs these steps in order:
 * 1. Load the sender key record.
 * 2. Derive a message key from the state's current chain key.
 * 3. AES-CBC/PKCS5 encrypt the plaintext.
 * 4. Build the message, signed with the state's private signing key.
 * 5. Advance the chain key and store the updated record.
 *
 * @param cipher                the group cipher (must not be NULL)
 * @param padded_plaintext      plaintext bytes to encrypt
 * @param padded_plaintext_len  number of bytes in padded_plaintext
 * @param encrypted_message     on success, receives the new message; it is
 *                              not released here, so the caller owns it.
 *                              Untouched on failure.
 * @return non-negative on success; a negative error code on failure, e.g.
 *         SG_ERR_INVAL when called from inside a decryption callback,
 *         SG_ERR_INVALID_KEY when the state has no private signing key, or
 *         SG_ERR_NO_SESSION (remapped from SG_ERR_INVALID_KEY_ID).
 */
int group_cipher_encrypt(group_cipher *cipher,
        const uint8_t *padded_plaintext, size_t padded_plaintext_len,
        ciphertext_message **encrypted_message)
{
    int result = 0;
    sender_key_message *result_message = 0;
    sender_key_record *record = 0;
    sender_key_state *state = 0;
    ec_private_key *signing_key_private = 0;
    sender_message_key *sender_key = 0;
    sender_chain_key *next_chain_key = 0;
    signal_buffer *sender_cipher_key = 0;
    signal_buffer *sender_cipher_iv = 0;
    signal_buffer *ciphertext = 0;

    assert(cipher);
    signal_lock(cipher->global_context);

    /* Encrypting from within the decryption callback is refused. */
    if(cipher->inside_callback == 1) {
        result = SG_ERR_INVAL;
        goto complete;
    }

    result = signal_protocol_sender_key_load_key(cipher->store, &record, cipher->sender_key_id);
    if(result < 0) {
        goto complete;
    }

    result = sender_key_record_get_sender_key_state(record, &state);
    if(result < 0) {
        goto complete;
    }

    /* Borrowed from state (not released below); without it we cannot sign. */
    signing_key_private = sender_key_state_get_signing_key_private(state);
    if(!signing_key_private) {
        result = SG_ERR_INVALID_KEY;
        goto complete;
    }

    result = sender_chain_key_create_message_key(sender_key_state_get_chain_key(state), &sender_key);
    if(result < 0) {
        goto complete;
    }

    /* Key and IV buffers are borrowed from sender_key, which is released at the end. */
    sender_cipher_key = sender_message_key_get_cipher_key(sender_key);
    sender_cipher_iv = sender_message_key_get_iv(sender_key);

    result = signal_encrypt(cipher->global_context, &ciphertext, SG_CIPHER_AES_CBC_PKCS5,
            signal_buffer_data(sender_cipher_key), signal_buffer_len(sender_cipher_key),
            signal_buffer_data(sender_cipher_iv), signal_buffer_len(sender_cipher_iv),
            padded_plaintext, padded_plaintext_len);
    if(result < 0) {
        goto complete;
    }

    result = sender_key_message_create(&result_message,
            sender_key_state_get_key_id(state),
            sender_message_key_get_iteration(sender_key),
            signal_buffer_data(ciphertext), signal_buffer_len(ciphertext),
            signing_key_private,
            cipher->global_context);
    if(result < 0) {
        goto complete;
    }

    /* Advance the chain so the next encryption uses a fresh message key. */
    result = sender_chain_key_create_next(sender_key_state_get_chain_key(state), &next_chain_key);
    if(result < 0) {
        goto complete;
    }

    sender_key_state_set_chain_key(state, next_chain_key);

    /* Persist the advanced chain; its result becomes the function's result. */
    result = signal_protocol_sender_key_store_key(cipher->store, cipher->sender_key_id, record);

complete:
    if(result >= 0) {
        /* Ownership of result_message transfers to the caller. */
        *encrypted_message = (ciphertext_message *)result_message;
    }
    else {
        if(result == SG_ERR_INVALID_KEY_ID) {
            result = SG_ERR_NO_SESSION;
        }
        SIGNAL_UNREF(result_message);
    }
    signal_buffer_free(ciphertext);
    SIGNAL_UNREF(next_chain_key);
    SIGNAL_UNREF(sender_key);
    SIGNAL_UNREF(record);
    signal_unlock(cipher->global_context);
    return result;
}
157
/*
 * Verify and decrypt a sender_key_message using the stored sender key
 * record for cipher->sender_key_id.
 *
 * The message's key id selects the sender key state. The message signature
 * is checked against that state's public signing key before the chain is
 * touched. If a decryption callback is installed, it runs after decryption
 * succeeds and before the updated record is stored. A negative return from
 * the callback aborts the call without storing the record.
 *
 * @param cipher           the group cipher (must not be NULL)
 * @param ciphertext       the received message
 * @param decrypt_context  passed unchanged to the decryption callback
 * @param plaintext        on success, receives the decrypted buffer; it is
 *                         not freed here, so the caller owns it.
 *                         Untouched on failure.
 * @return non-negative on success; a negative error code on failure, e.g.
 *         SG_ERR_INVAL (called from inside a decryption callback),
 *         SG_ERR_NO_SESSION (empty record), SG_ERR_DUPLICATE_MESSAGE, or
 *         SG_ERR_INVALID_MESSAGE (also remapped from SG_ERR_INVALID_KEY and
 *         SG_ERR_INVALID_KEY_ID).
 */
int group_cipher_decrypt(group_cipher *cipher,
        sender_key_message *ciphertext, void *decrypt_context,
        signal_buffer **plaintext)
{
    int result = 0;
    signal_buffer *result_buf = 0;
    sender_key_record *record = 0;
    sender_key_state *state = 0;
    sender_message_key *sender_key = 0;
    signal_buffer *sender_cipher_key = 0;
    signal_buffer *sender_cipher_iv = 0;
    signal_buffer *ciphertext_body = 0;

    assert(cipher);
    signal_lock(cipher->global_context);

    /* Re-entering decrypt from within the decryption callback is refused. */
    if(cipher->inside_callback == 1) {
        result = SG_ERR_INVAL;
        goto complete;
    }

    result = signal_protocol_sender_key_load_key(cipher->store, &record, cipher->sender_key_id);
    if(result < 0) {
        goto complete;
    }

    if(sender_key_record_is_empty(record)) {
        result = SG_ERR_NO_SESSION;
        /* NOTE(review): %s assumes group_id and sender.name are NUL-terminated
         * strings -- confirm against signal_protocol_sender_key_name. */
        signal_log(cipher->global_context, SG_LOG_WARNING, "No sender key for: %s::%s::%d",
                cipher->sender_key_id->group_id,
                cipher->sender_key_id->sender.name,
                cipher->sender_key_id->sender.device_id);
        goto complete;
    }

    result = sender_key_record_get_sender_key_state_by_id(record, &state, sender_key_message_get_key_id(ciphertext));
    if(result < 0) {
        goto complete;
    }

    /* Authenticate the message before advancing or consuming any chain state. */
    result = sender_key_message_verify_signature(ciphertext, sender_key_state_get_signing_key_public(state));
    if(result < 0) {
        goto complete;
    }

    /* May advance the chain and cache skipped keys in state (in memory only until stored). */
    result = group_cipher_get_sender_key(cipher, &sender_key, state, sender_key_message_get_iteration(ciphertext));
    if(result < 0) {
        goto complete;
    }

    /* Borrowed buffers: owned by sender_key / ciphertext respectively. */
    sender_cipher_key = sender_message_key_get_cipher_key(sender_key);
    sender_cipher_iv = sender_message_key_get_iv(sender_key);
    ciphertext_body = sender_key_message_get_ciphertext(ciphertext);

    result = signal_decrypt(cipher->global_context, &result_buf, SG_CIPHER_AES_CBC_PKCS5,
            signal_buffer_data(sender_cipher_key), signal_buffer_len(sender_cipher_key),
            signal_buffer_data(sender_cipher_iv), signal_buffer_len(sender_cipher_iv),
            signal_buffer_data(ciphertext_body), signal_buffer_len(ciphertext_body));
    if(result < 0) {
        goto complete;
    }

    /* Give the caller a chance to reject the plaintext before state is persisted. */
    result = group_cipher_decrypt_callback(cipher, result_buf, decrypt_context);
    if(result < 0) {
        goto complete;
    }

    result = signal_protocol_sender_key_store_key(cipher->store, cipher->sender_key_id, record);

complete:
    SIGNAL_UNREF(sender_key);
    SIGNAL_UNREF(record);
    if(result >= 0) {
        /* Ownership of result_buf transfers to the caller. */
        *plaintext = result_buf;
    }
    else {
        /* Key lookup/validation failures are reported as a bad message. */
        if(result == SG_ERR_INVALID_KEY || result == SG_ERR_INVALID_KEY_ID) {
            result = SG_ERR_INVALID_MESSAGE;
        }
        signal_buffer_free(result_buf);
    }
    signal_unlock(cipher->global_context);
    return result;
}
242
/*
 * Obtain the message key for `iteration` from `state`.
 *
 * - If `iteration` is behind the state's chain key, a previously skipped key
 *   cached in `state` is removed and returned.  When no such key is cached,
 *   the call fails with SG_ERR_DUPLICATE_MESSAGE.
 * - If `iteration` is more than GROUP_CIPHER_MAX_FUTURE_MESSAGES ahead of
 *   the chain, the call fails with SG_ERR_INVALID_MESSAGE.
 * - Otherwise the chain is stepped forward to `iteration`.  The message keys
 *   for every skipped iteration are cached in `state`, and the state's chain
 *   key is set to the one following `iteration`.  The key for `iteration`
 *   itself is returned.
 *
 * @param cipher      supplies the context used for logging
 * @param sender_key  on success, receives a message key the caller must
 *                    release; untouched on failure
 * @param state       sender key state to read and update (in memory only)
 * @param iteration   iteration number carried by the incoming message
 * @return 0 on success, or a negative error code
 */
static int group_cipher_get_sender_key(group_cipher *cipher, sender_message_key **sender_key, sender_key_state *state, uint32_t iteration)
{
    /* Upper bound on how many message keys one message may force us to derive. */
    enum { GROUP_CIPHER_MAX_FUTURE_MESSAGES = 2000 };

    int result = 0;
    sender_message_key *result_key = 0;
    sender_chain_key *chain_key = 0;
    sender_chain_key *next_chain_key = 0;
    sender_message_key *message_key = 0;
    uint32_t chain_iteration;

    /* Hold our own reference: chain_key is swapped out as we step forward. */
    chain_key = sender_key_state_get_chain_key(state);
    SIGNAL_REF(chain_key);
    chain_iteration = sender_chain_key_get_iteration(chain_key);

    if(chain_iteration > iteration) {
        if(sender_key_state_has_sender_message_key(state, iteration)) {
            /* Out-of-order delivery: consume the cached skipped key. */
            result_key = sender_key_state_remove_sender_message_key(state, iteration);
            if(!result_key) {
                result = SG_ERR_UNKNOWN;
            }
        }
        else {
            result = SG_ERR_DUPLICATE_MESSAGE;
            /* Both values are uint32_t (peer-controlled); log them unsigned. */
            signal_log(cipher->global_context, SG_LOG_WARNING,
                    "Received message with old counter: %lu, %lu",
                    (unsigned long)chain_iteration, (unsigned long)iteration);
        }
        goto complete;
    }

    /* iteration >= chain_iteration here, so the unsigned difference cannot wrap. */
    if(iteration - chain_iteration > GROUP_CIPHER_MAX_FUTURE_MESSAGES) {
        result = SG_ERR_INVALID_MESSAGE;
        signal_log(cipher->global_context, SG_LOG_WARNING, "Over 2000 messages into the future!");
        goto complete;
    }

    /* Step forward, caching the message key of every skipped iteration. */
    while(sender_chain_key_get_iteration(chain_key) < iteration) {
        result = sender_chain_key_create_message_key(chain_key, &message_key);
        if(result < 0) {
            goto complete;
        }

        result = sender_key_state_add_sender_message_key(state, message_key);
        if(result < 0) {
            goto complete;
        }
        SIGNAL_UNREF(message_key);

        result = sender_chain_key_create_next(chain_key, &next_chain_key);
        if(result < 0) {
            goto complete;
        }

        SIGNAL_UNREF(chain_key);
        chain_key = next_chain_key;
        next_chain_key = 0;
    }

    /* chain_key is now at `iteration`; the state moves on to the following key. */
    result = sender_chain_key_create_next(chain_key, &next_chain_key);
    if(result < 0) {
        goto complete;
    }

    sender_key_state_set_chain_key(state, next_chain_key);
    result = sender_chain_key_create_message_key(chain_key, &result_key);

complete:
    SIGNAL_UNREF(message_key);
    SIGNAL_UNREF(chain_key);
    SIGNAL_UNREF(next_chain_key);
    if(result >= 0) {
        *sender_key = result_key;
    }
    return result;
}
316
/*
 * Run the user's decryption hook, if one is installed, with the freshly
 * decrypted plaintext. While the hook runs, cipher->inside_callback is 1 so
 * that encrypt/decrypt calls made from inside it are rejected.
 *
 * @return the hook's return value, or 0 when no hook is installed
 */
static int group_cipher_decrypt_callback(group_cipher *cipher, signal_buffer *plaintext, void *decrypt_context)
{
    int rc;

    if(!cipher->decrypt_callback) {
        return 0;
    }

    cipher->inside_callback = 1;
    rc = cipher->decrypt_callback(cipher, plaintext, decrypt_context);
    cipher->inside_callback = 0;
    return rc;
}
327
/*
 * Release a cipher created by group_cipher_create(). Only the cipher object
 * itself is freed; the store, sender key name and context it references are
 * left to their owner. Passing NULL is a no-op.
 */
void group_cipher_free(group_cipher *cipher)
{
    /* free(NULL) is defined to do nothing, so no NULL guard is needed. */
    free(cipher);
}
334