#include "group_cipher.h"

#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include "protocol.h"
#include "sender_key.h"
#include "sender_key_record.h"
#include "sender_key_state.h"
#include "signal_protocol_internal.h"
10 
/*
 * Group (sender-key) cipher bound to a single sender key name.
 * store, sender_key_id and global_context are plain pointers assigned in
 * group_cipher_create(); group_cipher_free() releases only this struct.
 */
struct group_cipher
{
    /* Store the sender key record is loaded from and written back to. */
    signal_protocol_store_context *store;
    /* Name (group_id + sender) the record is stored under; borrowed, not copied. */
    const signal_protocol_sender_key_name *sender_key_id;
    /* Context used for signal_lock/signal_unlock, encryption and logging. */
    signal_context *global_context;
    /* Optional hook run on the plaintext after a successful decrypt, before the record is stored. */
    int (*decrypt_callback)(group_cipher *cipher, signal_buffer *plaintext, void *decrypt_context);
    /* Set to 1 while decrypt_callback runs; encrypt/decrypt return SG_ERR_INVAL while it is set. */
    int inside_callback;
    /* Opaque caller data; only stored and returned by the user_data accessors. */
    void *user_data;
};
20 
/* Internal helpers, defined later in this file. */
static int group_cipher_get_sender_key(group_cipher *cipher, sender_message_key **sender_key, sender_key_state *state, uint32_t iteration);
static int group_cipher_decrypt_callback(group_cipher *cipher, signal_buffer *plaintext, void *decrypt_context);
23 
/*
 * Allocate a group cipher bound to one sender key name.
 *
 * store, sender_key_id and global_context are stored by pointer (not copied
 * or reference-counted), so they must stay valid for the lifetime of the
 * cipher; group_cipher_free() does not release them.
 *
 * On success *cipher receives the new object (release it with
 * group_cipher_free()) and 0 is returned. If allocation fails,
 * SG_ERR_NOMEM is returned and *cipher is left untouched.
 */
int group_cipher_create(group_cipher **cipher,
        signal_protocol_store_context *store, const signal_protocol_sender_key_name *sender_key_id,
        signal_context *global_context)
{
    group_cipher *result_cipher;

    assert(cipher);
    assert(store);
    assert(global_context);

    /* calloc zero-fills: no callback, inside_callback == 0, user_data == NULL. */
    result_cipher = calloc(1, sizeof(*result_cipher));
    if(!result_cipher) {
        return SG_ERR_NOMEM;
    }

    result_cipher->store = store;
    result_cipher->sender_key_id = sender_key_id;
    result_cipher->global_context = global_context;

    *cipher = result_cipher;
    return 0;
}
46 
/*
 * Attach an opaque, caller-owned pointer to the cipher.
 * The value is only stored (read back via group_cipher_get_user_data());
 * the cipher never frees it.
 */
void group_cipher_set_user_data(group_cipher *cipher, void *user_data)
{
    assert(cipher);

    cipher->user_data = user_data;
}
52 
/*
 * Return the pointer last stored with group_cipher_set_user_data(),
 * or NULL if none was ever set (the struct is zero-initialised on creation).
 */
void *group_cipher_get_user_data(group_cipher *cipher)
{
    void *stored;

    assert(cipher);
    stored = cipher->user_data;
    return stored;
}
58 
/*
 * Install (or clear, by passing NULL) the hook invoked by group_cipher_decrypt()
 * with the decrypted plaintext. A negative return from the hook aborts the
 * decrypt before the updated sender key record is stored. While the hook runs,
 * group_cipher_encrypt()/group_cipher_decrypt() on the same cipher fail with
 * SG_ERR_INVAL.
 */
void group_cipher_set_decryption_callback(group_cipher *cipher,
        int (*callback)(group_cipher *cipher, signal_buffer *plaintext, void *decrypt_context))
{
    assert(cipher);

    cipher->decrypt_callback = callback;
}
65 
/*
 * Encrypt padded_plaintext as a group (sender-key) message.
 *
 * Loads the sender key record for cipher->sender_key_id, derives a message
 * key from the current chain key, encrypts with SG_CIPHER_AES_CBC_PKCS5,
 * builds a sender_key_message signed with the state's private signing key,
 * advances the chain key and writes the record back to the store. The whole
 * operation runs under signal_lock(cipher->global_context).
 *
 * On success returns >= 0 and stores the new message in *encrypted_message;
 * the reference created here is handed to the caller. On failure returns a
 * negative SG_ERR_* code and leaves *encrypted_message untouched, e.g.:
 *   SG_ERR_INVAL       - called from inside the decryption callback
 *   SG_ERR_INVALID_KEY - the state has no private signing key
 *   SG_ERR_NO_SESSION  - reported in place of SG_ERR_INVALID_KEY_ID
 */
int group_cipher_encrypt(group_cipher *cipher,
        const uint8_t *padded_plaintext, size_t padded_plaintext_len,
        ciphertext_message **encrypted_message)
{
    int result = 0;
    sender_key_message *result_message = 0;
    sender_key_record *record = 0;
    sender_key_state *state = 0;
    ec_private_key *signing_key_private = 0;
    sender_message_key *sender_key = 0;
    sender_chain_key *next_chain_key = 0;
    signal_buffer *sender_cipher_key = 0;
    signal_buffer *sender_cipher_iv = 0;
    signal_buffer *ciphertext = 0;

    assert(cipher);
    signal_lock(cipher->global_context);

    /* Re-entry from the decryption callback is rejected. */
    if(cipher->inside_callback == 1) {
        result = SG_ERR_INVAL;
        goto complete;
    }

    result = signal_protocol_sender_key_load_key(cipher->store, &record, cipher->sender_key_id);
    if(result < 0) {
        goto complete;
    }

    result = sender_key_record_get_sender_key_state(record, &state);
    if(result < 0) {
        goto complete;
    }

    /* The outgoing message must be signed, so a state without the private key cannot encrypt. */
    signing_key_private = sender_key_state_get_signing_key_private(state);
    if(!signing_key_private) {
        result = SG_ERR_INVALID_KEY;
        goto complete;
    }

    /* Message key for the current chain position. */
    result = sender_chain_key_create_message_key(sender_key_state_get_chain_key(state), &sender_key);
    if(result < 0) {
        goto complete;
    }

    /* Key and IV are borrowed from sender_key; they are not released separately below. */
    sender_cipher_key = sender_message_key_get_cipher_key(sender_key);
    sender_cipher_iv = sender_message_key_get_iv(sender_key);

    result = signal_encrypt(cipher->global_context, &ciphertext, SG_CIPHER_AES_CBC_PKCS5,
            signal_buffer_data(sender_cipher_key), signal_buffer_len(sender_cipher_key),
            signal_buffer_data(sender_cipher_iv), signal_buffer_len(sender_cipher_iv),
            padded_plaintext, padded_plaintext_len);
    if(result < 0) {
        goto complete;
    }

    /* The message carries the state's key id and the iteration so the receiver can pick the right key. */
    result = sender_key_message_create(&result_message,
            sender_key_state_get_key_id(state),
            sender_message_key_get_iteration(sender_key),
            signal_buffer_data(ciphertext), signal_buffer_len(ciphertext),
            signing_key_private,
            cipher->global_context);
    if(result < 0) {
        goto complete;
    }

    /* Advance the chain so the next encrypt derives a different message key. */
    result = sender_chain_key_create_next(sender_key_state_get_chain_key(state), &next_chain_key);
    if(result < 0) {
        goto complete;
    }

    sender_key_state_set_chain_key(state, next_chain_key);

    /* Persist the advanced chain; a store failure fails the whole call. */
    result = signal_protocol_sender_key_store_key(cipher->store, cipher->sender_key_id, record);

complete:
    if(result >= 0) {
        *encrypted_message = (ciphertext_message *)result_message;
    }
    else {
        /* Callers see SG_ERR_INVALID_KEY_ID as SG_ERR_NO_SESSION. */
        if(result == SG_ERR_INVALID_KEY_ID) {
            result = SG_ERR_NO_SESSION;
        }
        SIGNAL_UNREF(result_message);
    }
    signal_buffer_free(ciphertext);
    /* NOTE(review): assumes sender_key_state_set_chain_key took its own reference — confirm. */
    SIGNAL_UNREF(next_chain_key);
    SIGNAL_UNREF(sender_key);
    SIGNAL_UNREF(record);
    signal_unlock(cipher->global_context);
    return result;
}
157 
/*
 * Decrypt a group (sender-key) message.
 *
 * Loads the sender key record for cipher->sender_key_id, selects the state
 * matching the message's key id, and verifies the signature with that state's
 * public signing key. It then obtains the message key for the message's
 * iteration, which may consume a cached key or advance the chain. The body is
 * decrypted with SG_CIPHER_AES_CBC_PKCS5 and the optional decryption callback
 * runs. Only after all of that does the updated record get written back to
 * the store. The whole operation runs under signal_lock(cipher->global_context).
 *
 * decrypt_context is passed through unchanged to the decryption callback.
 *
 * On success returns >= 0 and stores the new plaintext buffer in *plaintext;
 * it is not freed here. On failure returns a negative SG_ERR_* code, frees
 * the plaintext and leaves *plaintext untouched, e.g.:
 *   SG_ERR_INVAL             - called from inside the decryption callback
 *   SG_ERR_NO_SESSION        - no sender key stored for this name
 *   SG_ERR_DUPLICATE_MESSAGE - old iteration with no cached key
 *   SG_ERR_INVALID_MESSAGE   - also reported in place of SG_ERR_INVALID_KEY
 *                              and SG_ERR_INVALID_KEY_ID
 */
int group_cipher_decrypt(group_cipher *cipher,
        sender_key_message *ciphertext, void *decrypt_context,
        signal_buffer **plaintext)
{
    int result = 0;
    signal_buffer *result_buf = 0;
    sender_key_record *record = 0;
    sender_key_state *state = 0;
    sender_message_key *sender_key = 0;
    signal_buffer *sender_cipher_key = 0;
    signal_buffer *sender_cipher_iv = 0;
    signal_buffer *ciphertext_body = 0;

    assert(cipher);
    signal_lock(cipher->global_context);

    /* Re-entry from the decryption callback is rejected. */
    if(cipher->inside_callback == 1) {
        result = SG_ERR_INVAL;
        goto complete;
    }

    result = signal_protocol_sender_key_load_key(cipher->store, &record, cipher->sender_key_id);
    if(result < 0) {
        goto complete;
    }

    if(sender_key_record_is_empty(record)) {
        result = SG_ERR_NO_SESSION;
        signal_log(cipher->global_context, SG_LOG_WARNING, "No sender key for: %s::%s::%d",
                cipher->sender_key_id->group_id,
                cipher->sender_key_id->sender.name,
                cipher->sender_key_id->sender.device_id);
        goto complete;
    }

    /* The record may hold several states; pick the one the sender used. */
    result = sender_key_record_get_sender_key_state_by_id(record, &state, sender_key_message_get_key_id(ciphertext));
    if(result < 0) {
        goto complete;
    }

    /* Authenticate before touching the chain, so forged messages cannot advance it. */
    result = sender_key_message_verify_signature(ciphertext, sender_key_state_get_signing_key_public(state));
    if(result < 0) {
        goto complete;
    }

    /* May consume a cached key or advance the chain; changes are in-memory until stored below. */
    result = group_cipher_get_sender_key(cipher, &sender_key, state, sender_key_message_get_iteration(ciphertext));
    if(result < 0) {
        goto complete;
    }

    /* Key, IV and body are borrowed; none of them is released separately below. */
    sender_cipher_key = sender_message_key_get_cipher_key(sender_key);
    sender_cipher_iv = sender_message_key_get_iv(sender_key);
    ciphertext_body = sender_key_message_get_ciphertext(ciphertext);

    result = signal_decrypt(cipher->global_context, &result_buf, SG_CIPHER_AES_CBC_PKCS5,
            signal_buffer_data(sender_cipher_key), signal_buffer_len(sender_cipher_key),
            signal_buffer_data(sender_cipher_iv), signal_buffer_len(sender_cipher_iv),
            signal_buffer_data(ciphertext_body), signal_buffer_len(ciphertext_body));
    if(result < 0) {
        goto complete;
    }

    /* A failing callback aborts before the updated record is persisted. */
    result = group_cipher_decrypt_callback(cipher, result_buf, decrypt_context);
    if(result < 0) {
        goto complete;
    }

    result = signal_protocol_sender_key_store_key(cipher->store, cipher->sender_key_id, record);

complete:
    SIGNAL_UNREF(sender_key);
    SIGNAL_UNREF(record);
    if(result >= 0) {
        *plaintext = result_buf;
    }
    else {
        /* Key-lookup failures are reported to callers as an invalid message. */
        if(result == SG_ERR_INVALID_KEY || result == SG_ERR_INVALID_KEY_ID) {
            result = SG_ERR_INVALID_MESSAGE;
        }
        signal_buffer_free(result_buf);
    }
    signal_unlock(cipher->global_context);
    return result;
}
242 
/*
 * Obtain the message key for `iteration` from `state`, updating the state.
 *
 * - iteration behind the chain: return the cached key for that iteration
 *   (removing it from the state), or SG_ERR_DUPLICATE_MESSAGE if none is
 *   cached.
 * - iteration more than MAX_FUTURE_MESSAGES ahead of the chain:
 *   SG_ERR_INVALID_MESSAGE.
 * - otherwise: cache a message key for every skipped iteration, derive the
 *   key for `iteration`, and set the state's chain key to the following one.
 *
 * On success *sender_key receives the new key and a value >= 0 is returned;
 * on failure a negative SG_ERR_* code is returned and *sender_key is
 * untouched.
 */
static int group_cipher_get_sender_key(group_cipher *cipher, sender_message_key **sender_key, sender_key_state *state, uint32_t iteration)
{
    /* Upper bound on how far ahead of the current chain a message may be. */
    enum { MAX_FUTURE_MESSAGES = 2000 };

    int result = 0;
    sender_message_key *result_key = 0;
    sender_chain_key *chain_key = 0;
    sender_chain_key *next_chain_key = 0;
    sender_message_key *message_key = 0;

    /* Hold our own reference: chain_key is replaced while walking forward. */
    chain_key = sender_key_state_get_chain_key(state);
    SIGNAL_REF(chain_key);

    if(sender_chain_key_get_iteration(chain_key) > iteration) {
        if(sender_key_state_has_sender_message_key(state, iteration)) {
            result_key = sender_key_state_remove_sender_message_key(state, iteration);
            if(!result_key) {
                result = SG_ERR_UNKNOWN;
            }
            goto complete;
        }
        else {
            result = SG_ERR_DUPLICATE_MESSAGE;
            /* Iterations are uint32_t and peer-controlled: print as unsigned, not %d. */
            signal_log(cipher->global_context, SG_LOG_WARNING,
                    "Received message with old counter: %u, %u",
                    (unsigned int)sender_chain_key_get_iteration(chain_key),
                    (unsigned int)iteration);
            goto complete;
        }
    }

    /* iteration >= chain iteration here, so this unsigned difference cannot wrap. */
    if(iteration - sender_chain_key_get_iteration(chain_key) > MAX_FUTURE_MESSAGES) {
        result = SG_ERR_INVALID_MESSAGE;
        signal_log(cipher->global_context, SG_LOG_WARNING, "Over 2000 messages into the future!");
        goto complete;
    }

    /* Cache keys for skipped iterations so out-of-order messages can still be decrypted. */
    while(sender_chain_key_get_iteration(chain_key) < iteration) {
        result = sender_chain_key_create_message_key(chain_key, &message_key);
        if(result < 0) {
            goto complete;
        }

        result = sender_key_state_add_sender_message_key(state, message_key);
        if(result < 0) {
            goto complete;
        }
        SIGNAL_UNREF(message_key);

        result = sender_chain_key_create_next(chain_key, &next_chain_key);
        if(result < 0) {
            goto complete;
        }

        SIGNAL_UNREF(chain_key);
        chain_key = next_chain_key;
        next_chain_key = 0;
    }

    /* chain_key is now at `iteration`; the state moves on to the next one. */
    result = sender_chain_key_create_next(chain_key, &next_chain_key);
    if(result < 0) {
        goto complete;
    }

    sender_key_state_set_chain_key(state, next_chain_key);
    result = sender_chain_key_create_message_key(chain_key, &result_key);

complete:
    SIGNAL_UNREF(message_key);
    SIGNAL_UNREF(chain_key);
    SIGNAL_UNREF(next_chain_key);
    if(result >= 0) {
        *sender_key = result_key;
    }
    return result;
}
316 
/*
 * Run the user's decryption callback, if one is installed, and return its
 * result; with no callback installed this returns 0. inside_callback is
 * raised for the duration of the call so encrypt/decrypt reject re-entry.
 */
static int group_cipher_decrypt_callback(group_cipher *cipher, signal_buffer *plaintext, void *decrypt_context)
{
    int callback_result;

    if(!cipher->decrypt_callback) {
        return 0;
    }

    cipher->inside_callback = 1;
    callback_result = cipher->decrypt_callback(cipher, plaintext, decrypt_context);
    cipher->inside_callback = 0;

    return callback_result;
}
327 
/*
 * Release a cipher created by group_cipher_create(). Passing NULL is a no-op.
 * Only the cipher struct itself is freed. The store, sender key name, global
 * context and user_data it points to are left untouched.
 */
void group_cipher_free(group_cipher *cipher)
{
    /* free(NULL) is defined to do nothing, so no explicit guard is needed. */
    free(cipher);
}
334