1 #include "signal_protocol.h"
2 
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <stdint.h>
6 #include <string.h>
7 #include <stdarg.h>
8 #include <assert.h>
9 
10 #include "signal_protocol_internal.h"
11 #include "signal_utarray.h"
12 
13 #ifdef _WINDOWS
14 #include "Windows.h"
15 #include "WinBase.h"
16 #endif
17 
#ifdef DEBUG_REFCOUNT
/* Global debug counters: type_ref_count is bumped by signal_type_init() and
 * signal_type_ref(), type_unref_count by signal_type_unref(). Both are reset
 * in signal_context_create() and printed by signal_context_destroy().
 * They are plain ints (not atomic). */
int type_ref_count = 0;
int type_unref_count = 0;
#endif

/* Smaller of two values. Each argument may be evaluated twice, so do not
 * pass expressions with side effects. */
#define MIN(a,b) (((a)<(b))?(a):(b))

/* Registry of the application-supplied storage callbacks that the
 * signal_protocol_* store wrappers in this file dispatch to. */
struct signal_protocol_store_context {
    /* Global context handed to record create/deserialize calls. */
    signal_context *global_context;
    signal_protocol_session_store session_store;
    signal_protocol_pre_key_store pre_key_store;
    signal_protocol_signed_pre_key_store signed_pre_key_store;
    signal_protocol_identity_key_store identity_key_store;
    signal_protocol_sender_key_store sender_key_store;
};
33 
/*
 * Initialise a reference-counted object. The caller holds the single initial
 * reference. destroy_func is invoked by signal_type_unref() when that last
 * reference is released.
 */
void signal_type_init(signal_type_base *instance,
        void (*destroy_func)(signal_type_base *instance))
{
#ifdef DEBUG_REFCOUNT
    type_ref_count++;
#endif
    instance->destroy = destroy_func;
    instance->ref_count = 1;
}
43 
/*
 * Take one more reference on a live object. The object must be non-NULL and
 * still referenced; both conditions are checked with assert() only.
 */
void signal_type_ref(signal_type_base *instance)
{
#ifdef DEBUG_REFCOUNT
    type_ref_count++;
#endif
    assert(instance != 0);
    assert(instance->ref_count >= 1);
    instance->ref_count += 1;
}
53 
/*
 * Drop one reference. NULL is accepted and ignored. When the caller holds the
 * last reference, the object's destroy callback runs instead of the count
 * being decremented.
 */
void signal_type_unref(signal_type_base *instance)
{
    if(!instance) {
        return;
    }
#ifdef DEBUG_REFCOUNT
    type_unref_count++;
#endif
    assert(instance->ref_count > 0);

    if(instance->ref_count <= 1) {
        instance->destroy(instance);
        return;
    }
    instance->ref_count--;
}
69 
#ifdef DEBUG_REFCOUNT
/* Debug helper: the current reference count of instance, or 0 for NULL. */
int signal_type_ref_count(signal_type_base *instance)
{
    return instance ? instance->ref_count : 0;
}
#endif
79 
80 /*------------------------------------------------------------------------*/
81 
signal_buffer_alloc(size_t len)82 signal_buffer *signal_buffer_alloc(size_t len)
83 {
84     signal_buffer *buffer;
85     if(len > (SIZE_MAX - sizeof(struct signal_buffer)) / sizeof(uint8_t)) {
86         return 0;
87     }
88 
89     buffer = malloc(sizeof(struct signal_buffer) + (sizeof(uint8_t) * len));
90     if(buffer) {
91         buffer->len = len;
92     }
93     return buffer;
94 }
95 
signal_buffer_create(const uint8_t * data,size_t len)96 signal_buffer *signal_buffer_create(const uint8_t *data, size_t len)
97 {
98     signal_buffer *buffer = signal_buffer_alloc(len);
99     if(!buffer) {
100         return 0;
101     }
102 
103     memcpy(buffer->data, data, len);
104     return buffer;
105 }
106 
signal_buffer_copy(const signal_buffer * buffer)107 signal_buffer *signal_buffer_copy(const signal_buffer *buffer)
108 {
109     return signal_buffer_create(buffer->data, buffer->len);
110 }
111 
signal_buffer_n_copy(const signal_buffer * buffer,size_t n)112 signal_buffer *signal_buffer_n_copy(const signal_buffer *buffer, size_t n)
113 {
114     size_t len = MIN(buffer->len, n);
115     return signal_buffer_create(buffer->data, len);
116 }
117 
/*
 * Grow buffer and append len bytes from data to its end.
 *
 * On success the (possibly relocated) buffer is returned, and the old
 * pointer must no longer be used. On failure (size overflow or realloc()
 * failure) 0 is returned. The original buffer is then left untouched and
 * still owned by the caller. data may be NULL only when len is 0.
 */
signal_buffer *signal_buffer_append(signal_buffer *buffer, const uint8_t *data, size_t len)
{
    signal_buffer *tmp_buffer;
    size_t previous_size = buffer->len;
    size_t previous_alloc = sizeof(struct signal_buffer) + (sizeof(uint8_t) * previous_size);

    /* Reject appends whose new allocation size would overflow size_t. */
    if(len > (SIZE_MAX - previous_alloc)) {
        return 0;
    }

    /* A failed realloc() leaves the original block valid. */
    tmp_buffer = realloc(buffer, previous_alloc + (sizeof(uint8_t) * len));
    if(!tmp_buffer) {
        return 0;
    }

    /* memcpy() with a NULL source is undefined even for a zero length. */
    if(len > 0) {
        memcpy(tmp_buffer->data + previous_size, data, len);
    }
    tmp_buffer->len = previous_size + len;
    return tmp_buffer;
}
137 
/* Mutable pointer to the first payload byte stored in buffer. */
uint8_t *signal_buffer_data(signal_buffer *buffer)
{
    return &buffer->data[0];
}
142 
signal_buffer_const_data(const signal_buffer * buffer)143 const uint8_t *signal_buffer_const_data(const signal_buffer *buffer)
144 {
145     return buffer->data;
146 }
147 
signal_buffer_len(const signal_buffer * buffer)148 size_t signal_buffer_len(const signal_buffer *buffer)
149 {
150     return buffer->len;
151 }
152 
/*
 * Compare two buffers.
 *
 * NULL orders before any non-NULL buffer, and a shorter buffer orders before
 * a longer one. For equal lengths the contents are compared in constant time
 * with signal_constant_memcmp(). The result is then 0 when the contents are
 * equal and non-zero otherwise; it is not a lexicographic ordering.
 */
int signal_buffer_compare(signal_buffer *buffer1, signal_buffer *buffer2)
{
    if(buffer1 == buffer2) {
        return 0;
    }
    if(!buffer1) {
        return -1;
    }
    if(!buffer2) {
        return 1;
    }
    if(buffer1->len != buffer2->len) {
        return (buffer1->len < buffer2->len) ? -1 : 1;
    }
    return signal_constant_memcmp(buffer1->data, buffer2->data, buffer1->len);
}
176 
/* Release a buffer obtained from the signal_buffer_* allocators. NULL is a no-op. */
void signal_buffer_free(signal_buffer *buffer)
{
    /* free(NULL) does nothing, so no explicit guard is needed. */
    free(buffer);
}
183 
/*
 * Zero the payload with signal_explicit_bzero() before freeing the buffer,
 * so sensitive contents do not linger in released memory. NULL is a no-op.
 */
void signal_buffer_bzero_free(signal_buffer *buffer)
{
    if(!buffer) {
        return;
    }
    signal_explicit_bzero(buffer->data, buffer->len);
    free(buffer);
}
191 
192 /*------------------------------------------------------------------------*/
193 
/* Growable list of signal_buffer pointers. The list owns its elements:
 * signal_buffer_list_free() frees each one. */
struct signal_buffer_list
{
    UT_array *values;
};
198 
signal_buffer_list_alloc(void)199 signal_buffer_list *signal_buffer_list_alloc(void)
200 {
201     int result = 0;
202     signal_buffer_list *list = malloc(sizeof(signal_buffer_list));
203     if(!list) {
204         result = SG_ERR_NOMEM;
205         goto complete;
206     }
207 
208     memset(list, 0, sizeof(signal_buffer_list));
209 
210     utarray_new(list->values, &ut_ptr_icd);
211 
212 complete:
213     if(result < 0) {
214         if(list) {
215             free(list);
216         }
217         return 0;
218     }
219     else {
220         return list;
221     }
222 }
223 
signal_buffer_list_copy(const signal_buffer_list * list)224 signal_buffer_list *signal_buffer_list_copy(const signal_buffer_list *list)
225 {
226     int result = 0;
227     signal_buffer_list *result_list = 0;
228     signal_buffer *buffer_copy = 0;
229     unsigned int list_size;
230     unsigned int i;
231 
232     result_list = signal_buffer_list_alloc();
233     if(!result_list) {
234         result = SG_ERR_NOMEM;
235         goto complete;
236     }
237 
238     list_size = utarray_len(list->values);
239 
240     utarray_reserve(result_list->values, list_size);
241 
242     for(i = 0; i < list_size; i++) {
243         signal_buffer **buffer = (signal_buffer**)utarray_eltptr(list->values, i);
244         buffer_copy = signal_buffer_copy(*buffer);
245         utarray_push_back(result_list->values, &buffer_copy);
246         buffer_copy = 0;
247     }
248 
249 complete:
250     if(result < 0) {
251         signal_buffer_free(buffer_copy);
252         signal_buffer_list_free(result_list);
253         return 0;
254     }
255     else {
256         return result_list;
257     }
258 }
259 
/*
 * Append buffer to the end of list. The list takes ownership:
 * signal_buffer_list_free() frees it later. Returns 0 on success.
 *
 * NOTE(review): `result` and the `complete` label are not referenced in this
 * body. They are presumably required by the oom() hook that signal_utarray.h
 * installs for utarray_push_back(). Confirm before removing them.
 */
int signal_buffer_list_push_back(signal_buffer_list *list, signal_buffer *buffer)
{
    int result = 0;
    signal_buffer *element = buffer;

    assert(list != 0);
    utarray_push_back(list->values, &element);

complete:
    return result;
}
269 
/* Number of buffers currently stored in list. */
unsigned int signal_buffer_list_size(signal_buffer_list *list)
{
    unsigned int count;

    assert(list);
    count = utarray_len(list->values);
    return count;
}
275 
signal_buffer_list_at(signal_buffer_list * list,unsigned int index)276 signal_buffer *signal_buffer_list_at(signal_buffer_list *list, unsigned int index)
277 {
278     signal_buffer **value = 0;
279 
280     assert(list);
281     assert(index < utarray_len(list->values));
282 
283     value = (signal_buffer**)utarray_eltptr(list->values, index);
284 
285     assert(*value);
286 
287     return *value;
288 }
289 
/* Free every buffer held by list, then the list itself. NULL is a no-op. */
void signal_buffer_list_free(signal_buffer_list *list)
{
    unsigned int count;
    unsigned int idx;

    if(!list) {
        return;
    }

    count = utarray_len(list->values);
    for(idx = 0; idx < count; idx++) {
        signal_buffer_free(*(signal_buffer **)utarray_eltptr(list->values, idx));
    }
    utarray_free(list->values);
    free(list);
}
305 
/*
 * Like signal_buffer_list_free(), but each element's payload is zeroed via
 * signal_buffer_bzero_free() before release. NULL is a no-op.
 */
void signal_buffer_list_bzero_free(signal_buffer_list *list)
{
    unsigned int count;
    unsigned int idx;

    if(!list) {
        return;
    }

    count = utarray_len(list->values);
    for(idx = 0; idx < count; idx++) {
        signal_buffer_bzero_free(*(signal_buffer **)utarray_eltptr(list->values, idx));
    }
    utarray_free(list->values);
    free(list);
}
321 
322 /*------------------------------------------------------------------------*/
323 
/* Growable list of ints, stored by value. */
struct signal_int_list
{
    UT_array *values;
};
328 
signal_int_list_alloc()329 signal_int_list *signal_int_list_alloc()
330 {
331     int result = 0;
332     signal_int_list *list = malloc(sizeof(signal_int_list));
333     if(!list) {
334         result = SG_ERR_NOMEM;
335         goto complete;
336     }
337 
338     memset(list, 0, sizeof(signal_int_list));
339 
340     utarray_new(list->values, &ut_int_icd);
341 
342 complete:
343     if(result < 0) {
344         if(list) {
345             free(list);
346         }
347         return 0;
348     }
349     else {
350         return list;
351     }
352 }
353 
/*
 * Append value to the end of list. Returns 0 on success.
 *
 * NOTE(review): `result` and the `complete` label are not referenced in this
 * body. They are presumably required by the oom() hook that signal_utarray.h
 * installs for utarray_push_back(). Confirm before removing them.
 */
int signal_int_list_push_back(signal_int_list *list, int value)
{
    int result = 0;
    int element = value;

    assert(list != 0);
    utarray_push_back(list->values, &element);

complete:
    return result;
}
363 
/* Number of ints currently stored in list. */
unsigned int signal_int_list_size(signal_int_list *list)
{
    unsigned int count;

    assert(list);
    count = utarray_len(list->values);
    return count;
}
369 
/* Value stored at index. index must be in range (checked with assert() only). */
int signal_int_list_at(signal_int_list *list, unsigned int index)
{
    const int *slot;

    assert(list);
    assert(index < utarray_len(list->values));

    slot = (const int *)utarray_eltptr(list->values, index);
    assert(slot);
    return *slot;
}
383 
/* Release the list and its storage. NULL is a no-op. */
void signal_int_list_free(signal_int_list *list)
{
    if(!list) {
        return;
    }
    utarray_free(list->values);
    free(list);
}
391 
392 /*------------------------------------------------------------------------*/
393 
/*
 * Allocate a zero-initialised global context holding user_data. user_data is
 * passed back to the lock/unlock/log callbacks. With DEBUG_REFCOUNT the
 * global reference counters are reset. Returns 0, or SG_ERR_NOMEM with
 * *context set to NULL.
 */
int signal_context_create(signal_context **context, void *user_data)
{
    signal_context *ctx = calloc(1, sizeof(signal_context));

    *context = ctx;
    if(!ctx) {
        return SG_ERR_NOMEM;
    }
    ctx->user_data = user_data;
#ifdef DEBUG_REFCOUNT
    type_ref_count = 0;
    type_unref_count = 0;
#endif
    return 0;
}
408 
/*
 * Copy crypto_provider into the context. The four HMAC-SHA256 callbacks are
 * mandatory (SG_ERR_INVAL otherwise). The remaining callbacks are only
 * checked with assert() at the point of use.
 */
int signal_context_set_crypto_provider(signal_context *context, const signal_crypto_provider *crypto_provider)
{
    assert(context);

    if(crypto_provider == 0) {
        return SG_ERR_INVAL;
    }
    if(!crypto_provider->hmac_sha256_init_func
            || !crypto_provider->hmac_sha256_update_func
            || !crypto_provider->hmac_sha256_final_func
            || !crypto_provider->hmac_sha256_cleanup_func) {
        return SG_ERR_INVAL;
    }

    context->crypto_provider = *crypto_provider;
    return 0;
}
422 
/*
 * Install the lock/unlock callbacks used by signal_lock()/signal_unlock().
 * Both must be supplied or both NULL; a mismatched pair yields SG_ERR_INVAL.
 */
int signal_context_set_locking_functions(signal_context *context,
        void (*lock)(void *user_data), void (*unlock)(void *user_data))
{
    assert(context);

    /* Exactly one NULL callback is the invalid case. */
    if(!lock != !unlock) {
        return SG_ERR_INVAL;
    }

    context->unlock = unlock;
    context->lock = lock;
    return 0;
}
435 
/* Install the callback that receives signal_log() output; NULL disables logging. */
int signal_context_set_log_function(signal_context *context,
        void (*log)(int level, const char *message, size_t len, void *user_data))
{
    assert(context != 0);
    context->log = log;
    return 0;
}
443 
/*
 * Free the global context (NULL is a no-op). With DEBUG_REFCOUNT the global
 * ref/unref totals are first reported on stderr.
 */
void signal_context_destroy(signal_context *context)
{
#ifdef DEBUG_REFCOUNT
    fprintf(stderr, "Global REF count: %d\n", type_ref_count);
    fprintf(stderr, "Global UNREF count: %d\n", type_unref_count);
#endif
    free(context);
}
454 
455 /*------------------------------------------------------------------------*/
456 
/* Fill data with len random bytes via the crypto provider's random_func. */
int signal_crypto_random(signal_context *context, uint8_t *data, size_t len)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->random_func);
    return provider->random_func(data, len, provider->user_data);
}
463 
/* Start an HMAC-SHA256 computation keyed with key, via the crypto provider. */
int signal_hmac_sha256_init(signal_context *context, void **hmac_context, const uint8_t *key, size_t key_len)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->hmac_sha256_init_func);
    return provider->hmac_sha256_init_func(hmac_context, key, key_len, provider->user_data);
}
470 
/* Feed data_len bytes into an HMAC-SHA256 computation started with init. */
int signal_hmac_sha256_update(signal_context *context, void *hmac_context, const uint8_t *data, size_t data_len)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->hmac_sha256_update_func);
    return provider->hmac_sha256_update_func(hmac_context, data, data_len, provider->user_data);
}
477 
/* Finish an HMAC-SHA256 computation, producing the MAC in *output. */
int signal_hmac_sha256_final(signal_context *context, void *hmac_context, signal_buffer **output)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->hmac_sha256_final_func);
    return provider->hmac_sha256_final_func(hmac_context, output, provider->user_data);
}
484 
/* Release an HMAC-SHA256 context via the crypto provider's cleanup callback. */
void signal_hmac_sha256_cleanup(signal_context *context, void *hmac_context)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->hmac_sha256_cleanup_func);
    provider->hmac_sha256_cleanup_func(hmac_context, provider->user_data);
}
491 
/* Start a SHA-512 digest computation via the crypto provider. */
int signal_sha512_digest_init(signal_context *context, void **digest_context)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->sha512_digest_init_func);
    return provider->sha512_digest_init_func(digest_context, provider->user_data);
}
498 
/* Feed data_len bytes into a SHA-512 digest computation. */
int signal_sha512_digest_update(signal_context *context, void *digest_context, const uint8_t *data, size_t data_len)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->sha512_digest_update_func);
    return provider->sha512_digest_update_func(digest_context, data, data_len, provider->user_data);
}
505 
/* Finish a SHA-512 digest computation, producing the digest in *output. */
int signal_sha512_digest_final(signal_context *context, void *digest_context, signal_buffer **output)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->sha512_digest_final_func);
    return provider->sha512_digest_final_func(digest_context, output, provider->user_data);
}
512 
/*
 * Release a SHA-512 digest context via the crypto provider's cleanup callback.
 * The callback is called as a plain statement: `return <expr>;` is not
 * permitted in a function returning void (C11 6.8.6.4p1).
 */
void signal_sha512_digest_cleanup(signal_context *context, void *digest_context)
{
    assert(context);
    assert(context->crypto_provider.sha512_digest_cleanup_func);
    context->crypto_provider.sha512_digest_cleanup_func(digest_context, context->crypto_provider.user_data);
}
519 
/* Encrypt plaintext with the given cipher/key/IV via the crypto provider. */
int signal_encrypt(signal_context *context,
        signal_buffer **output,
        int cipher,
        const uint8_t *key, size_t key_len,
        const uint8_t *iv, size_t iv_len,
        const uint8_t *plaintext, size_t plaintext_len)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->encrypt_func);
    return provider->encrypt_func(output, cipher,
            key, key_len,
            iv, iv_len,
            plaintext, plaintext_len,
            provider->user_data);
}
534 
/* Decrypt ciphertext with the given cipher/key/IV via the crypto provider. */
int signal_decrypt(signal_context *context,
        signal_buffer **output,
        int cipher,
        const uint8_t *key, size_t key_len,
        const uint8_t *iv, size_t iv_len,
        const uint8_t *ciphertext, size_t ciphertext_len)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->decrypt_func);
    return provider->decrypt_func(output, cipher,
            key, key_len,
            iv, iv_len,
            ciphertext, ciphertext_len,
            provider->user_data);
}
549 
/* Invoke the installed lock callback, if any, with the context's user_data. */
void signal_lock(signal_context *context)
{
    if(!context->lock) {
        return;
    }
    context->lock(context->user_data);
}
556 
/* Invoke the installed unlock callback, if any, with the context's user_data. */
void signal_unlock(signal_context *context)
{
    if(!context->unlock) {
        return;
    }
    context->unlock(context->user_data);
}
563 
/*
 * printf-style logging through the context's log callback. Does nothing if
 * context or its callback is NULL. The formatted message is limited to 255
 * characters (longer output is truncated by vsnprintf). Empty output and
 * formatting errors are not forwarded.
 */
void signal_log(signal_context *context, int level, const char *format, ...)
{
    char message[256];
    int written;
    va_list args;

    if(!context || !context->log) {
        return;
    }

    va_start(args, format);
    written = vsnprintf(message, sizeof(message), format, args);
    va_end(args);

    if(written <= 0) {
        return;
    }
    context->log(level, message, strlen(message), context->user_data);
}
578 
/*
 * Zero n bytes at v in a way the compiler should not optimise away. It uses
 * SecureZeroMemory or memset_s when the build enables them, and otherwise
 * writes through a volatile pointer.
 */
void signal_explicit_bzero(void *v, size_t n)
{
#ifdef HAVE_SECUREZEROMEMORY
    SecureZeroMemory(v, n);
#elif HAVE_MEMSET_S
    memset_s(v, n, 0, n);
#else
    /* Volatile stores are observable side effects and may not be elided. */
    volatile unsigned char *bytes = (volatile unsigned char *)v;
    size_t i;
    for(i = 0; i < n; i++) {
        bytes[i] = 0;
    }
#endif
}
590 
/*
 * Compare n bytes of s1 and s2 without early exit, so the running time does
 * not depend on where the inputs differ. Returns 0 when they are equal.
 * Otherwise it returns the bitwise OR of all byte-wise XOR differences
 * (1..255); this is not an ordering.
 */
int signal_constant_memcmp(const void *s1, const void *s2, size_t n)
{
    const unsigned char *lhs = (const unsigned char *)s1;
    const unsigned char *rhs = (const unsigned char *)s2;
    unsigned char diff = 0;

    while(n-- > 0) {
        diff |= (unsigned char)(*lhs++ ^ *rhs++);
    }

    return diff;
}
604 
/*
 * Point buffer at the characters of str, excluding the NUL terminator. No
 * copy is made: buffer borrows str, which must outlive any use of buffer.
 */
void signal_protocol_str_serialize_protobuf(ProtobufCBinaryData *buffer, const char *str)
{
    size_t str_len;

    assert(buffer);
    assert(str);

    str_len = strlen(str);
    buffer->len = str_len;
    buffer->data = (uint8_t *)str;
}
612 
/*
 * Return a newly malloc()ed, NUL-terminated copy of the bytes in buffer, or 0
 * on allocation failure or size overflow. The caller frees the result.
 * Embedded NUL bytes are copied verbatim, so strlen() of the result may be
 * shorter than buffer->len.
 */
char *signal_protocol_str_deserialize_protobuf(ProtobufCBinaryData *buffer)
{
    char *str = 0;
    assert(buffer);

    /* buffer->len + 1 must not wrap around to a tiny allocation. */
    if(buffer->len > SIZE_MAX - 1) {
        return 0;
    }

    str = malloc(buffer->len + 1);
    if(!str) {
        return 0;
    }

    /* An empty bytes field may carry a NULL data pointer; memcpy() from NULL
     * is undefined even for a zero length. */
    if(buffer->len > 0) {
        memcpy(str, buffer->data, buffer->len);
    }
    str[buffer->len] = '\0';

    return str;
}
628 
629 /*------------------------------------------------------------------------*/
630 
/*
 * Allocate an empty store context bound to global_context; all store
 * callbacks start zeroed. Returns 0, or SG_ERR_NOMEM with *context set to NULL.
 */
int signal_protocol_store_context_create(signal_protocol_store_context **context, signal_context *global_context)
{
    signal_protocol_store_context *store_ctx;

    assert(global_context);

    store_ctx = calloc(1, sizeof(signal_protocol_store_context));
    *context = store_ctx;
    if(!store_ctx) {
        return SG_ERR_NOMEM;
    }
    store_ctx->global_context = global_context;
    return 0;
}
642 
/* Copy the session-store callback table into context; NULL store yields SG_ERR_INVAL. */
int signal_protocol_store_context_set_session_store(signal_protocol_store_context *context, const signal_protocol_session_store *store)
{
    if(store == 0) {
        return SG_ERR_INVAL;
    }
    context->session_store = *store;
    return 0;
}
651 
/* Copy the pre-key-store callback table into context; NULL store yields SG_ERR_INVAL. */
int signal_protocol_store_context_set_pre_key_store(signal_protocol_store_context *context, const signal_protocol_pre_key_store *store)
{
    if(store == 0) {
        return SG_ERR_INVAL;
    }
    context->pre_key_store = *store;
    return 0;
}
660 
/* Copy the signed-pre-key-store callback table into context; NULL store yields SG_ERR_INVAL. */
int signal_protocol_store_context_set_signed_pre_key_store(signal_protocol_store_context *context, const signal_protocol_signed_pre_key_store *store)
{
    if(store == 0) {
        return SG_ERR_INVAL;
    }
    context->signed_pre_key_store = *store;
    return 0;
}
669 
/* Copy the identity-key-store callback table into context; NULL store yields SG_ERR_INVAL. */
int signal_protocol_store_context_set_identity_key_store(signal_protocol_store_context *context, const signal_protocol_identity_key_store *store)
{
    if(store == 0) {
        return SG_ERR_INVAL;
    }
    context->identity_key_store = *store;
    return 0;
}
678 
/* Copy the sender-key-store callback table into context; NULL store yields SG_ERR_INVAL. */
int signal_protocol_store_context_set_sender_key_store(signal_protocol_store_context *context, const signal_protocol_sender_key_store *store)
{
    if(store == 0) {
        return SG_ERR_INVAL;
    }
    context->sender_key_store = *store;
    return 0;
}
687 
/*
 * Destroy a store context. Each registered store's destroy_func (when set) is
 * called with that store's user_data, then the context is freed. NULL is a no-op.
 */
void signal_protocol_store_context_destroy(signal_protocol_store_context *context)
{
    if(!context) {
        return;
    }

    if(context->session_store.destroy_func) {
        context->session_store.destroy_func(context->session_store.user_data);
    }
    if(context->pre_key_store.destroy_func) {
        context->pre_key_store.destroy_func(context->pre_key_store.user_data);
    }
    if(context->signed_pre_key_store.destroy_func) {
        context->signed_pre_key_store.destroy_func(context->signed_pre_key_store.user_data);
    }
    if(context->identity_key_store.destroy_func) {
        context->identity_key_store.destroy_func(context->identity_key_store.user_data);
    }
    if(context->sender_key_store.destroy_func) {
        context->sender_key_store.destroy_func(context->sender_key_store.user_data);
    }

    free(context);
}
709 
710 /*------------------------------------------------------------------------*/
711 
/*
 * Load the session record stored for address.
 *
 * The store callback's return value selects the outcome:
 *   < 0   error, propagated unchanged;
 *   0     no stored session: a new empty record is created (the callback
 *         must not also have returned a serialized buffer);
 *   1     stored session found: the serialized buffer is deserialized;
 *   other SG_ERR_UNKNOWN.
 * On success (result >= 0), *record receives the record, and any user_buffer
 * returned by the store is attached to it. On failure, *record is left
 * untouched and user_buffer is freed. The serialized buffer is always freed here.
 */
int signal_protocol_session_load_session(signal_protocol_store_context *context, session_record **record, const signal_protocol_address *address)
{
    int result = 0;
    signal_buffer *buffer = 0;
    signal_buffer *user_buffer = 0;
    session_record *result_record = 0;

    assert(context);
    assert(context->session_store.load_session_func);

    result = context->session_store.load_session_func(
            &buffer, &user_buffer, address,
            context->session_store.user_data);
    if(result < 0) {
        goto complete;
    }

    if(result == 0) {
        /* "Not found" must not come with serialized data. */
        if(buffer) {
            result = SG_ERR_UNKNOWN;
            goto complete;
        }
        result = session_record_create(&result_record, 0, context->global_context);
    }
    else if(result == 1) {
        if(!buffer) {
            /* NOTE(review): bare -1 instead of a named SG_ERR_* code as used
             * by the other branches; callers presumably only test < 0. */
            result = -1;
            goto complete;
        }
        result = session_record_deserialize(&result_record,
                signal_buffer_data(buffer), signal_buffer_len(buffer), context->global_context);
    }
    else {
        result = SG_ERR_UNKNOWN;
    }

complete:
    if(buffer) {
        signal_buffer_free(buffer);
    }
    if(result >= 0) {
        if(user_buffer) {
            /* NOTE(review): presumably hands ownership of user_buffer to the
             * record, since it is not freed on this path. */
            session_record_set_user_record(result_record, user_buffer);
        }
        *record = result_record;
    }
    else {
        signal_buffer_free(user_buffer);
    }
    return result;
}
763 
/* Ask the session store for the device ids that have sessions with name. */
int signal_protocol_session_get_sub_device_sessions(signal_protocol_store_context *context, signal_int_list **sessions, const char *name, size_t name_len)
{
    const signal_protocol_session_store *store;

    assert(context);
    store = &context->session_store;
    assert(store->get_sub_device_sessions_func);

    return store->get_sub_device_sessions_func(sessions, name, name_len, store->user_data);
}
773 
/*
 * Serialize record and hand it to the session store for address, together
 * with the record's optional user record bytes. Returns the serializer's
 * error or the store callback's result.
 */
int signal_protocol_session_store_session(signal_protocol_store_context *context, const signal_protocol_address *address, session_record *record)
{
    int result;
    signal_buffer *serialized = 0;
    signal_buffer *user_record;
    uint8_t *user_record_data = 0;
    size_t user_record_len = 0;

    assert(context);
    assert(context->session_store.store_session_func);
    assert(record);

    result = session_record_serialize(&serialized, record);
    if(result >= 0) {
        user_record = session_record_get_user_record(record);
        if(user_record) {
            user_record_data = signal_buffer_data(user_record);
            user_record_len = signal_buffer_len(user_record);
        }

        result = context->session_store.store_session_func(
                address,
                signal_buffer_data(serialized), signal_buffer_len(serialized),
                user_record_data, user_record_len,
                context->session_store.user_data);
    }

    signal_buffer_free(serialized);
    return result;
}
810 
/* Ask the session store whether a session exists for address. */
int signal_protocol_session_contains_session(signal_protocol_store_context *context, const signal_protocol_address *address)
{
    const signal_protocol_session_store *store;

    assert(context);
    store = &context->session_store;
    assert(store->contains_session_func);

    return store->contains_session_func(address, store->user_data);
}
820 
/* Ask the session store to delete the session for address. */
int signal_protocol_session_delete_session(signal_protocol_store_context *context, const signal_protocol_address *address)
{
    const signal_protocol_session_store *store;

    assert(context);
    store = &context->session_store;
    assert(store->delete_session_func);

    return store->delete_session_func(address, store->user_data);
}
830 
/* Ask the session store to delete every session for name. */
int signal_protocol_session_delete_all_sessions(signal_protocol_store_context *context, const char *name, size_t name_len)
{
    const signal_protocol_session_store *store;

    assert(context);
    store = &context->session_store;
    assert(store->delete_all_sessions_func);

    return store->delete_all_sessions_func(name, name_len, store->user_data);
}
840 
841 /*------------------------------------------------------------------------*/
842 
/*
 * Load and deserialize the pre-key with id pre_key_id from the pre-key store.
 *
 * Store errors (< 0) are propagated unchanged. A store that reports success
 * without returning data yields SG_ERR_UNKNOWN. On success *pre_key receives
 * the deserialized key; on failure it is left untouched. The serialized
 * buffer is always freed here.
 */
int signal_protocol_pre_key_load_key(signal_protocol_store_context *context, session_pre_key **pre_key, uint32_t pre_key_id)
{
    int result = 0;
    signal_buffer *buffer = 0;
    session_pre_key *result_key = 0;

    assert(context);
    assert(context->pre_key_store.load_pre_key);

    result = context->pre_key_store.load_pre_key(
            &buffer, pre_key_id,
            context->pre_key_store.user_data);
    if(result < 0) {
        goto complete;
    }

    /* signal_buffer_data()/signal_buffer_len() would dereference NULL. */
    if(!buffer) {
        result = SG_ERR_UNKNOWN;
        goto complete;
    }

    result = session_pre_key_deserialize(&result_key,
            signal_buffer_data(buffer), signal_buffer_len(buffer), context->global_context);

complete:
    signal_buffer_free(buffer);
    if(result >= 0) {
        *pre_key = result_key;
    }
    return result;
}
871 
/*
 * Serialize pre_key and hand it to the pre-key store under its own id.
 * Returns the serializer's error or the store callback's result.
 */
int signal_protocol_pre_key_store_key(signal_protocol_store_context *context, session_pre_key *pre_key)
{
    int result;
    uint32_t key_id;
    signal_buffer *serialized = 0;

    assert(context);
    assert(context->pre_key_store.store_pre_key);
    assert(pre_key);

    key_id = session_pre_key_get_id(pre_key);

    result = session_pre_key_serialize(&serialized, pre_key);
    if(result >= 0) {
        result = context->pre_key_store.store_pre_key(
                key_id,
                signal_buffer_data(serialized), signal_buffer_len(serialized),
                context->pre_key_store.user_data);
    }

    signal_buffer_free(serialized);
    return result;
}
901 
/*
 * Ask the application's pre key store whether a pre key with the given ID
 * is present. The contains_pre_key callback's return value is passed back
 * unchanged.
 */
int signal_protocol_pre_key_contains_key(signal_protocol_store_context *context, uint32_t pre_key_id)
{
    assert(context);
    assert(context->pre_key_store.contains_pre_key);

    return context->pre_key_store.contains_pre_key(pre_key_id,
            context->pre_key_store.user_data);
}
914 
/*
 * Remove the pre key with the given ID from the application's pre key store.
 * The remove_pre_key callback's return value is passed back unchanged.
 */
int signal_protocol_pre_key_remove_key(signal_protocol_store_context *context, uint32_t pre_key_id)
{
    assert(context);
    assert(context->pre_key_store.remove_pre_key);

    return context->pre_key_store.remove_pre_key(pre_key_id,
            context->pre_key_store.user_data);
}
927 
928 /*------------------------------------------------------------------------*/
929 
/*
 * Fetch the serialized signed pre key with the given ID from the
 * application's signed pre key store and deserialize it.
 *
 * On success (return >= 0) *pre_key receives the newly deserialized key;
 * this function does not release it, so the caller ends up holding that
 * reference. On failure *pre_key is not written and the negative code from
 * the store callback or the deserializer is returned. The temporary
 * serialization buffer is released on every path.
 */
int signal_protocol_signed_pre_key_load_key(signal_protocol_store_context *context, session_signed_pre_key **pre_key, uint32_t signed_pre_key_id)
{
    signal_buffer *serialized = 0;
    session_signed_pre_key *loaded = 0;
    int rc;

    assert(context);
    assert(context->signed_pre_key_store.load_signed_pre_key);

    rc = context->signed_pre_key_store.load_signed_pre_key(&serialized,
            signed_pre_key_id, context->signed_pre_key_store.user_data);
    if(rc >= 0) {
        rc = session_signed_pre_key_deserialize(&loaded,
                signal_buffer_data(serialized), signal_buffer_len(serialized),
                context->global_context);
    }

    if(serialized) {
        signal_buffer_free(serialized);
    }
    if(rc >= 0) {
        *pre_key = loaded;
    }
    return rc;
}
958 
/*
 * Serialize a signed pre key and hand the bytes to the application's signed
 * pre key store, keyed by the signed pre key's own ID.
 *
 * Returns the store callback's result, or the serializer's negative error
 * code if serialization fails (in which case the store is not called).
 * The temporary serialization buffer is always released.
 */
int signal_protocol_signed_pre_key_store_key(signal_protocol_store_context *context, session_signed_pre_key *pre_key)
{
    signal_buffer *serialized = 0;
    uint32_t key_id;
    int rc;

    assert(context);
    assert(context->signed_pre_key_store.store_signed_pre_key);
    assert(pre_key);

    key_id = session_signed_pre_key_get_id(pre_key);

    rc = session_signed_pre_key_serialize(&serialized, pre_key);
    if(rc >= 0) {
        rc = context->signed_pre_key_store.store_signed_pre_key(key_id,
                signal_buffer_data(serialized), signal_buffer_len(serialized),
                context->signed_pre_key_store.user_data);
    }

    if(serialized) {
        signal_buffer_free(serialized);
    }
    return rc;
}
988 
/*
 * Ask the application's signed pre key store whether a signed pre key with
 * the given ID is present. The contains_signed_pre_key callback's return
 * value is passed back unchanged.
 */
int signal_protocol_signed_pre_key_contains_key(signal_protocol_store_context *context, uint32_t signed_pre_key_id)
{
    assert(context);
    assert(context->signed_pre_key_store.contains_signed_pre_key);

    return context->signed_pre_key_store.contains_signed_pre_key(signed_pre_key_id,
            context->signed_pre_key_store.user_data);
}
1001 
/*
 * Remove the signed pre key with the given ID from the application's signed
 * pre key store. The remove_signed_pre_key callback's return value is passed
 * back unchanged.
 */
int signal_protocol_signed_pre_key_remove_key(signal_protocol_store_context *context, uint32_t signed_pre_key_id)
{
    assert(context);
    assert(context->signed_pre_key_store.remove_signed_pre_key);

    return context->signed_pre_key_store.remove_signed_pre_key(signed_pre_key_id,
            context->signed_pre_key_store.user_data);
}
1014 
1015 /*------------------------------------------------------------------------*/
1016 
/*
 * Fetch the local identity key pair from the application's identity key
 * store and decode it into a ratchet_identity_key_pair.
 *
 * The get_identity_key_pair callback supplies the serialized public and
 * private keys in two separate buffers. Each is decoded with the curve
 * helpers and the results are combined into a key pair.
 *
 * On success (return >= 0) *key_pair receives the new key pair. On failure
 * *key_pair is not written and a negative error code is returned. If the
 * callback reports success but leaves either buffer NULL, SG_ERR_UNKNOWN is
 * returned instead of dereferencing the missing buffer.
 *
 * All intermediate buffers and this function's references to the decoded
 * keys are released on every path.
 */
int signal_protocol_identity_get_key_pair(signal_protocol_store_context *context, ratchet_identity_key_pair **key_pair)
{
    int result = 0;
    signal_buffer *public_buf = 0;
    signal_buffer *private_buf = 0;
    ec_public_key *public_key = 0;
    ec_private_key *private_key = 0;
    ratchet_identity_key_pair *result_key = 0;

    assert(context);
    assert(context->identity_key_store.get_identity_key_pair);

    result = context->identity_key_store.get_identity_key_pair(
            &public_buf, &private_buf,
            context->identity_key_store.user_data);
    if(result < 0) {
        goto complete;
    }

    /* A store reporting success must hand back both halves of the pair;
     * the decode calls below would otherwise dereference NULL. */
    if(!public_buf || !private_buf) {
        result = SG_ERR_UNKNOWN;
        goto complete;
    }

    result = curve_decode_point(&public_key, public_buf->data, public_buf->len, context->global_context);
    if(result < 0) {
        goto complete;
    }

    result = curve_decode_private_point(&private_key, private_buf->data, private_buf->len, context->global_context);
    if(result < 0) {
        goto complete;
    }

    /* The local references to public_key/private_key are dropped below, so
     * the key pair presumably takes its own references to both keys. */
    result = ratchet_identity_key_pair_create(&result_key, public_key, private_key);

complete:
    if(public_buf) {
        signal_buffer_free(public_buf);
    }
    if(private_buf) {
        /* NOTE(review): this buffer holds secret key bytes; consider zeroing
         * it before release if the library provides a helper for that. */
        signal_buffer_free(private_buf);
    }
    if(public_key) {
        SIGNAL_UNREF(public_key);
    }
    if(private_key) {
        SIGNAL_UNREF(private_key);
    }
    if(result >= 0) {
        *key_pair = result_key;
    }
    return result;
}
1069 
/*
 * Query the application's identity key store for the local registration ID.
 *
 * registration_id is handed to the get_local_registration_id callback as an
 * output pointer. Note that this callback takes user_data BEFORE the output
 * argument. The callback's return value is passed back unchanged.
 */
int signal_protocol_identity_get_local_registration_id(signal_protocol_store_context *context, uint32_t *registration_id)
{
    assert(context);
    assert(context->identity_key_store.get_local_registration_id);

    return context->identity_key_store.get_local_registration_id(
            context->identity_key_store.user_data, registration_id);
}
1082 
/*
 * Record an identity key for a remote address in the application's identity
 * key store.
 *
 * A non-NULL identity_key is serialized and its bytes are passed to the
 * save_identity callback. A NULL identity_key is forwarded as (NULL, 0).
 * Returns the callback's result, or the serializer's negative error code if
 * serialization fails (the callback is then not invoked).
 */
int signal_protocol_identity_save_identity(signal_protocol_store_context *context, const signal_protocol_address *address, ec_public_key *identity_key)
{
    signal_buffer *serialized = 0;
    uint8_t *key_bytes = 0;
    size_t key_size = 0;
    int rc = 0;

    assert(context);
    assert(context->identity_key_store.save_identity);

    if(identity_key) {
        rc = ec_public_key_serialize(&serialized, identity_key);
        if(rc < 0) {
            goto cleanup;
        }
        key_bytes = signal_buffer_data(serialized);
        key_size = signal_buffer_len(serialized);
    }

    rc = context->identity_key_store.save_identity(address, key_bytes, key_size,
            context->identity_key_store.user_data);

cleanup:
    if(serialized) {
        signal_buffer_free(serialized);
    }
    return rc;
}
1116 
/*
 * Ask the application's identity key store whether identity_key is trusted
 * for the given remote address.
 *
 * The key is serialized and its bytes are passed to the is_trusted_identity
 * callback, whose return value is passed back unchanged. If serialization
 * fails, the serializer's negative error code is returned and the callback
 * is not invoked. The temporary serialization buffer is always released.
 */
int signal_protocol_identity_is_trusted_identity(signal_protocol_store_context *context, const signal_protocol_address *address, ec_public_key *identity_key)
{
    signal_buffer *serialized = 0;
    int rc;

    assert(context);
    assert(context->identity_key_store.is_trusted_identity);

    rc = ec_public_key_serialize(&serialized, identity_key);
    if(rc >= 0) {
        rc = context->identity_key_store.is_trusted_identity(address,
                signal_buffer_data(serialized), signal_buffer_len(serialized),
                context->identity_key_store.user_data);
    }

    if(serialized) {
        signal_buffer_free(serialized);
    }
    return rc;
}
1142 
/*
 * Serialize a sender key record and pass it to the application's sender key
 * store under sender_key_name.
 *
 * If the record carries a user record buffer, its bytes are forwarded
 * alongside the serialized record. Otherwise (NULL, 0) is passed. The user
 * record buffer is only read here and is not freed.
 *
 * Returns the store callback's result, or the serializer's negative error
 * code if serialization fails (the callback is then not invoked). The
 * temporary serialization buffer is always released.
 */
int signal_protocol_sender_key_store_key(signal_protocol_store_context *context, const signal_protocol_sender_key_name *sender_key_name, sender_key_record *record)
{
    signal_buffer *serialized = 0;
    signal_buffer *attached = 0;
    uint8_t *attached_bytes = 0;
    size_t attached_size = 0;
    int rc;

    assert(context);
    assert(context->sender_key_store.store_sender_key);
    assert(record);

    rc = sender_key_record_serialize(&serialized, record);
    if(rc >= 0) {
        attached = sender_key_record_get_user_record(record);
        if(attached) {
            attached_bytes = signal_buffer_data(attached);
            attached_size = signal_buffer_len(attached);
        }

        rc = context->sender_key_store.store_sender_key(sender_key_name,
                signal_buffer_data(serialized), signal_buffer_len(serialized),
                attached_bytes, attached_size,
                context->sender_key_store.user_data);
    }

    if(serialized) {
        signal_buffer_free(serialized);
    }
    return rc;
}
1179 
/*
 * Load the sender key record identified by sender_key_name from the
 * application's sender key store.
 *
 * The load_sender_key callback's return value is interpreted as follows:
 *   1        - a record was found; its serialized form must be in the first
 *              buffer and is deserialized.
 *   0        - no record exists; the first buffer must be left NULL, and a
 *              fresh, empty record is created instead.
 *   < 0      - an error, returned unchanged.
 *   anything else, or a buffer that contradicts the 0/1 result
 *            - treated as an error, SG_ERR_UNKNOWN.
 *
 * An optional user record buffer from the callback is attached to the
 * resulting record with sender_key_record_set_user_record (and not freed
 * here) on success, or freed on failure.
 *
 * On success (return >= 0) *record receives the record. On failure *record
 * is not written.
 */
int signal_protocol_sender_key_load_key(signal_protocol_store_context *context, sender_key_record **record, const signal_protocol_sender_key_name *sender_key_name)
{
    int result = 0;
    signal_buffer *buffer = 0;
    signal_buffer *user_buffer = 0;
    sender_key_record *result_record = 0;

    assert(context);
    assert(context->sender_key_store.load_sender_key);

    result = context->sender_key_store.load_sender_key(
            &buffer, &user_buffer, sender_key_name,
            context->sender_key_store.user_data);
    if(result < 0) {
        goto complete;
    }

    if(result == 0) {
        /* "Not found" must not come with serialized data. */
        if(buffer) {
            result = SG_ERR_UNKNOWN;
            goto complete;
        }
        result = sender_key_record_create(&result_record, context->global_context);
    }
    else if(result == 1) {
        /* "Found" must come with serialized data. Report the mismatch with
         * the same error code as the other contract violations, rather than
         * a bare -1 that is not an SG_ERR value. */
        if(!buffer) {
            result = SG_ERR_UNKNOWN;
            goto complete;
        }
        result = sender_key_record_deserialize(&result_record,
                signal_buffer_data(buffer), signal_buffer_len(buffer), context->global_context);
    }
    else {
        result = SG_ERR_UNKNOWN;
    }

complete:
    if(buffer) {
        signal_buffer_free(buffer);
    }
    if(result >= 0) {
        if(user_buffer) {
            /* Handed to the record; not freed here. */
            sender_key_record_set_user_record(result_record, user_buffer);
        }
        *record = result_record;
    }
    else {
        signal_buffer_free(user_buffer);
    }
    return result;
}
1231