1 #include "signal_protocol.h"
2
3 #include <stdio.h>
4 #include <stdlib.h>
5 #include <stdint.h>
6 #include <string.h>
7 #include <stdarg.h>
8 #include <assert.h>
9
10 #include "signal_protocol_internal.h"
11 #include "signal_utarray.h"
12
13 #ifdef _WINDOWS
14 #include "Windows.h"
15 #include "WinBase.h"
16 #endif
17
#ifdef DEBUG_REFCOUNT
/*
 * Debug-only global tallies. type_ref_count is bumped by signal_type_init()
 * and signal_type_ref(); type_unref_count by signal_type_unref(). Both are
 * reset in signal_context_create() and printed in signal_context_destroy().
 */
int type_ref_count = 0;
int type_unref_count = 0;
#endif

/* Smaller of two values. Function-like macro: an argument may be evaluated twice. */
#define MIN(a,b) (((a)<(b))?(a):(b))
24
/*
 * Bundles the global signal_context with one instance of each store
 * callback table. Every table is held by value; the matching
 * signal_protocol_store_context_set_*_store() call fills it with memcpy.
 */
struct signal_protocol_store_context {
    signal_context *global_context;
    signal_protocol_session_store session_store;
    signal_protocol_pre_key_store pre_key_store;
    signal_protocol_signed_pre_key_store signed_pre_key_store;
    signal_protocol_identity_key_store identity_key_store;
    signal_protocol_sender_key_store sender_key_store;
};
33
/*
 * Set up the reference-counting header of a new object.
 *
 * instance      object header to initialise; must not be null
 * destroy_func  callback stored for signal_type_unref() to invoke when the
 *               last reference is dropped
 *
 * The object starts with a reference count of one.
 */
void signal_type_init(signal_type_base *instance,
        void (*destroy_func)(signal_type_base *instance))
{
#ifdef DEBUG_REFCOUNT
    type_ref_count += 1;
#endif
    instance->destroy = destroy_func;
    instance->ref_count = 1;
}
43
/*
 * Take an additional reference on a live object.
 *
 * instance  object header; asserted to be non-null and to already hold at
 *           least one reference
 */
void signal_type_ref(signal_type_base *instance)
{
#ifdef DEBUG_REFCOUNT
    type_ref_count += 1;
#endif
    assert(instance);
    assert(instance->ref_count > 0);
    instance->ref_count += 1;
}
53
/*
 * Drop one reference from an object; a null pointer is ignored.
 *
 * When the caller holds the last reference the stored destroy callback is
 * invoked (the counter itself is left untouched in that case); otherwise the
 * counter is simply decremented.
 */
void signal_type_unref(signal_type_base *instance)
{
    if(!instance) {
        return;
    }

#ifdef DEBUG_REFCOUNT
    type_unref_count += 1;
#endif
    assert(instance->ref_count > 0);

    if(instance->ref_count <= 1) {
        instance->destroy(instance);
        return;
    }

    instance->ref_count -= 1;
}
69
#ifdef DEBUG_REFCOUNT
/*
 * Debug helper: report the current reference count of an object.
 * Returns 0 for a null pointer.
 */
int signal_type_ref_count(signal_type_base *instance)
{
    if(instance) {
        return instance->ref_count;
    }
    return 0;
}
#endif
79
80 /*------------------------------------------------------------------------*/
81
signal_buffer_alloc(size_t len)82 signal_buffer *signal_buffer_alloc(size_t len)
83 {
84 signal_buffer *buffer;
85 if(len > (SIZE_MAX - sizeof(struct signal_buffer)) / sizeof(uint8_t)) {
86 return 0;
87 }
88
89 buffer = malloc(sizeof(struct signal_buffer) + (sizeof(uint8_t) * len));
90 if(buffer) {
91 buffer->len = len;
92 }
93 return buffer;
94 }
95
signal_buffer_create(const uint8_t * data,size_t len)96 signal_buffer *signal_buffer_create(const uint8_t *data, size_t len)
97 {
98 signal_buffer *buffer = signal_buffer_alloc(len);
99 if(!buffer) {
100 return 0;
101 }
102
103 memcpy(buffer->data, data, len);
104 return buffer;
105 }
106
signal_buffer_copy(const signal_buffer * buffer)107 signal_buffer *signal_buffer_copy(const signal_buffer *buffer)
108 {
109 return signal_buffer_create(buffer->data, buffer->len);
110 }
111
signal_buffer_n_copy(const signal_buffer * buffer,size_t n)112 signal_buffer *signal_buffer_n_copy(const signal_buffer *buffer, size_t n)
113 {
114 size_t len = MIN(buffer->len, n);
115 return signal_buffer_create(buffer->data, len);
116 }
117
/*
 * Grow a buffer by len bytes and copy data onto its end.
 *
 * Returns the (possibly relocated) buffer, which replaces the pointer passed
 * in. Returns 0 when the new size would overflow size_t or realloc fails; in
 * both cases the buffer passed in has not been released by this function.
 */
signal_buffer *signal_buffer_append(signal_buffer *buffer, const uint8_t *data, size_t len)
{
    const size_t old_len = buffer->len;
    const size_t old_bytes = sizeof(struct signal_buffer) + (sizeof(uint8_t) * old_len);
    signal_buffer *grown;

    if(len > (SIZE_MAX - old_bytes)) {
        return 0;
    }

    grown = realloc(buffer, old_bytes + (sizeof(uint8_t) * len));
    if(grown) {
        memcpy(grown->data + old_len, data, len);
        grown->len = old_len + len;
    }
    return grown;
}
137
/* Writable pointer to the first payload byte of buffer (must not be null). */
uint8_t *signal_buffer_data(signal_buffer *buffer)
{
    uint8_t *payload = buffer->data;
    return payload;
}
142
signal_buffer_const_data(const signal_buffer * buffer)143 const uint8_t *signal_buffer_const_data(const signal_buffer *buffer)
144 {
145 return buffer->data;
146 }
147
signal_buffer_len(const signal_buffer * buffer)148 size_t signal_buffer_len(const signal_buffer *buffer)
149 {
150 return buffer->len;
151 }
152
/*
 * Order two buffers.
 *
 * Identical pointers (including two nulls) compare equal. A null buffer
 * sorts before a non-null one, and a shorter buffer before a longer one.
 * Buffers of equal length are handed to signal_constant_memcmp(), so the
 * result is then 0 when the contents match and non-zero otherwise.
 */
int signal_buffer_compare(signal_buffer *buffer1, signal_buffer *buffer2)
{
    if(buffer1 == buffer2) {
        return 0;
    }
    /* The pointers differ, so at most one of them can be null. */
    if(!buffer1) {
        return -1;
    }
    if(!buffer2) {
        return 1;
    }

    if(buffer1->len != buffer2->len) {
        return (buffer1->len < buffer2->len) ? -1 : 1;
    }

    return signal_constant_memcmp(buffer1->data, buffer2->data, buffer1->len);
}
176
/* Release a buffer. Passing null is harmless because free() accepts it. */
void signal_buffer_free(signal_buffer *buffer)
{
    free(buffer);
}
183
/*
 * Wipe a buffer's payload with signal_explicit_bzero() and then release it.
 * A null pointer is ignored.
 */
void signal_buffer_bzero_free(signal_buffer *buffer)
{
    if(!buffer) {
        return;
    }

    signal_explicit_bzero(buffer->data, buffer->len);
    free(buffer);
}
191
192 /*------------------------------------------------------------------------*/
193
/*
 * List of signal_buffer pointers backed by a UT_array created with
 * ut_ptr_icd (see signal_buffer_list_alloc()).
 */
struct signal_buffer_list
{
    UT_array *values;
};
198
signal_buffer_list_alloc(void)199 signal_buffer_list *signal_buffer_list_alloc(void)
200 {
201 int result = 0;
202 signal_buffer_list *list = malloc(sizeof(signal_buffer_list));
203 if(!list) {
204 result = SG_ERR_NOMEM;
205 goto complete;
206 }
207
208 memset(list, 0, sizeof(signal_buffer_list));
209
210 utarray_new(list->values, &ut_ptr_icd);
211
212 complete:
213 if(result < 0) {
214 if(list) {
215 free(list);
216 }
217 return 0;
218 }
219 else {
220 return list;
221 }
222 }
223
signal_buffer_list_copy(const signal_buffer_list * list)224 signal_buffer_list *signal_buffer_list_copy(const signal_buffer_list *list)
225 {
226 int result = 0;
227 signal_buffer_list *result_list = 0;
228 signal_buffer *buffer_copy = 0;
229 unsigned int list_size;
230 unsigned int i;
231
232 result_list = signal_buffer_list_alloc();
233 if(!result_list) {
234 result = SG_ERR_NOMEM;
235 goto complete;
236 }
237
238 list_size = utarray_len(list->values);
239
240 utarray_reserve(result_list->values, list_size);
241
242 for(i = 0; i < list_size; i++) {
243 signal_buffer **buffer = (signal_buffer**)utarray_eltptr(list->values, i);
244 buffer_copy = signal_buffer_copy(*buffer);
245 utarray_push_back(result_list->values, &buffer_copy);
246 buffer_copy = 0;
247 }
248
249 complete:
250 if(result < 0) {
251 signal_buffer_free(buffer_copy);
252 signal_buffer_list_free(result_list);
253 return 0;
254 }
255 else {
256 return result_list;
257 }
258 }
259
/*
 * Append a buffer pointer to the list. Only the pointer is stored; the
 * buffer is not copied. signal_buffer_list_free() later frees each stored
 * buffer.
 *
 * Returns 0, unless the utarray macro changes `result`.
 *
 * NOTE(review): `result` and `complete` look unused here but are presumably
 * the target of the utarray out-of-memory hook; confirm in signal_utarray.h.
 */
int signal_buffer_list_push_back(signal_buffer_list *list, signal_buffer *buffer)
{
    int result = 0;

    assert(list != 0);

    utarray_push_back(list->values, &buffer);

complete:
    return result;
}
269
/* Number of buffers currently stored in list (asserted non-null). */
unsigned int signal_buffer_list_size(signal_buffer_list *list)
{
    unsigned int count;

    assert(list);
    count = utarray_len(list->values);
    return count;
}
275
signal_buffer_list_at(signal_buffer_list * list,unsigned int index)276 signal_buffer *signal_buffer_list_at(signal_buffer_list *list, unsigned int index)
277 {
278 signal_buffer **value = 0;
279
280 assert(list);
281 assert(index < utarray_len(list->values));
282
283 value = (signal_buffer**)utarray_eltptr(list->values, index);
284
285 assert(*value);
286
287 return *value;
288 }
289
/*
 * Free every buffer held by the list, then the backing array and the list
 * itself. A null list is ignored.
 */
void signal_buffer_list_free(signal_buffer_list *list)
{
    unsigned int count;
    unsigned int pos = 0;

    if(!list) {
        return;
    }

    count = utarray_len(list->values);
    while(pos < count) {
        signal_buffer **slot = (signal_buffer **)utarray_eltptr(list->values, pos);
        signal_buffer_free(*slot);
        pos++;
    }

    utarray_free(list->values);
    free(list);
}
305
/*
 * Same as signal_buffer_list_free(), except each buffer is released with
 * signal_buffer_bzero_free() so its payload is wiped first.
 * A null list is ignored.
 */
void signal_buffer_list_bzero_free(signal_buffer_list *list)
{
    unsigned int count;
    unsigned int pos = 0;

    if(!list) {
        return;
    }

    count = utarray_len(list->values);
    while(pos < count) {
        signal_buffer **slot = (signal_buffer **)utarray_eltptr(list->values, pos);
        signal_buffer_bzero_free(*slot);
        pos++;
    }

    utarray_free(list->values);
    free(list);
}
321
322 /*------------------------------------------------------------------------*/
323
/*
 * List of int values backed by a UT_array created with ut_int_icd
 * (see signal_int_list_alloc()).
 */
struct signal_int_list
{
    UT_array *values;
};
328
signal_int_list_alloc()329 signal_int_list *signal_int_list_alloc()
330 {
331 int result = 0;
332 signal_int_list *list = malloc(sizeof(signal_int_list));
333 if(!list) {
334 result = SG_ERR_NOMEM;
335 goto complete;
336 }
337
338 memset(list, 0, sizeof(signal_int_list));
339
340 utarray_new(list->values, &ut_int_icd);
341
342 complete:
343 if(result < 0) {
344 if(list) {
345 free(list);
346 }
347 return 0;
348 }
349 else {
350 return list;
351 }
352 }
353
/*
 * Append value to the list.
 *
 * Returns 0, unless the utarray macro changes `result`.
 *
 * NOTE(review): `result` and `complete` look unused here but are presumably
 * the target of the utarray out-of-memory hook; confirm in signal_utarray.h.
 */
int signal_int_list_push_back(signal_int_list *list, int value)
{
    int result = 0;

    assert(list != 0);

    utarray_push_back(list->values, &value);

complete:
    return result;
}
363
/* Number of integers currently stored in list (asserted non-null). */
unsigned int signal_int_list_size(signal_int_list *list)
{
    unsigned int count;

    assert(list);
    count = utarray_len(list->values);
    return count;
}
369
/*
 * Read the integer stored at position index. The list, the index range and
 * the element pointer are all checked with assert().
 */
int signal_int_list_at(signal_int_list *list, unsigned int index)
{
    int *slot;

    assert(list);
    assert(index < utarray_len(list->values));

    slot = (int *)utarray_eltptr(list->values, index);
    assert(slot);

    return *slot;
}
383
/* Release the backing array and the list itself. A null list is ignored. */
void signal_int_list_free(signal_int_list *list)
{
    if(!list) {
        return;
    }

    utarray_free(list->values);
    free(list);
}
391
392 /*------------------------------------------------------------------------*/
393
/*
 * Allocate a zero-filled global context and attach user_data to it.
 *
 * context    receives the new context; set to null if allocation fails
 * user_data  opaque pointer stored in the context
 *
 * Returns 0 on success or SG_ERR_NOMEM. With DEBUG_REFCOUNT defined, the
 * global reference tallies are reset on success.
 */
int signal_context_create(signal_context **context, void *user_data)
{
    signal_context *created = calloc(1, sizeof(signal_context));

    *context = created;
    if(!created) {
        return SG_ERR_NOMEM;
    }

    created->user_data = user_data;
#ifdef DEBUG_REFCOUNT
    type_ref_count = 0;
    type_unref_count = 0;
#endif
    return 0;
}
408
/*
 * Copy a crypto provider table into the context.
 *
 * Returns SG_ERR_INVAL when crypto_provider is null or any of its four
 * HMAC-SHA256 callbacks (init, update, final, cleanup) is missing; no other
 * callback is checked here. Returns 0 once the table has been copied.
 */
int signal_context_set_crypto_provider(signal_context *context, const signal_crypto_provider *crypto_provider)
{
    assert(context);

    if(!crypto_provider) {
        return SG_ERR_INVAL;
    }
    if(!(crypto_provider->hmac_sha256_init_func
            && crypto_provider->hmac_sha256_update_func
            && crypto_provider->hmac_sha256_final_func
            && crypto_provider->hmac_sha256_cleanup_func)) {
        return SG_ERR_INVAL;
    }

    memcpy(&(context->crypto_provider), crypto_provider, sizeof(signal_crypto_provider));
    return 0;
}
422
/*
 * Install the lock/unlock callbacks used by signal_lock()/signal_unlock().
 *
 * The two callbacks must be supplied together or both be null; a lone
 * callback yields SG_ERR_INVAL and leaves the context unchanged.
 * Returns 0 on success.
 */
int signal_context_set_locking_functions(signal_context *context,
        void (*lock)(void *user_data), void (*unlock)(void *user_data))
{
    assert(context);

    /* Exactly one of the two being null is the invalid combination. */
    if((lock == 0) != (unlock == 0)) {
        return SG_ERR_INVAL;
    }

    context->unlock = unlock;
    context->lock = lock;
    return 0;
}
435
/*
 * Install the callback that signal_log() delivers formatted messages to.
 * A null callback is accepted (signal_log() then does nothing).
 * Always returns 0.
 */
int signal_context_set_log_function(signal_context *context,
        void (*log)(int level, const char *message, size_t len, void *user_data))
{
    assert(context != 0);

    context->log = log;

    return 0;
}
443
/*
 * Release a global context; null is accepted. With DEBUG_REFCOUNT defined
 * the global ref/unref tallies are printed to stderr first.
 */
void signal_context_destroy(signal_context *context)
{
#ifdef DEBUG_REFCOUNT
    fprintf(stderr, "Global REF count: %d\n", type_ref_count);
    fprintf(stderr, "Global UNREF count: %d\n", type_unref_count);
#endif
    free(context);
}
454
455 /*------------------------------------------------------------------------*/
456
/*
 * Forward to the crypto provider's random_func with the provider's
 * user_data. The context and the callback are asserted non-null.
 * Returns whatever the callback returns.
 */
int signal_crypto_random(signal_context *context, uint8_t *data, size_t len)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->random_func);

    return provider->random_func(data, len, provider->user_data);
}
463
/*
 * Forward to the crypto provider's hmac_sha256_init_func with the
 * provider's user_data. The context and the callback are asserted non-null.
 * Returns whatever the callback returns.
 */
int signal_hmac_sha256_init(signal_context *context, void **hmac_context, const uint8_t *key, size_t key_len)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->hmac_sha256_init_func);

    return provider->hmac_sha256_init_func(hmac_context, key, key_len, provider->user_data);
}
470
/*
 * Forward to the crypto provider's hmac_sha256_update_func with the
 * provider's user_data. The context and the callback are asserted non-null.
 * Returns whatever the callback returns.
 */
int signal_hmac_sha256_update(signal_context *context, void *hmac_context, const uint8_t *data, size_t data_len)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->hmac_sha256_update_func);

    return provider->hmac_sha256_update_func(hmac_context, data, data_len, provider->user_data);
}
477
/*
 * Forward to the crypto provider's hmac_sha256_final_func with the
 * provider's user_data. The context and the callback are asserted non-null.
 * Returns whatever the callback returns.
 */
int signal_hmac_sha256_final(signal_context *context, void *hmac_context, signal_buffer **output)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->hmac_sha256_final_func);

    return provider->hmac_sha256_final_func(hmac_context, output, provider->user_data);
}
484
/*
 * Forward to the crypto provider's hmac_sha256_cleanup_func with the
 * provider's user_data. The context and the callback are asserted non-null.
 */
void signal_hmac_sha256_cleanup(signal_context *context, void *hmac_context)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->hmac_sha256_cleanup_func);

    provider->hmac_sha256_cleanup_func(hmac_context, provider->user_data);
}
491
/*
 * Forward to the crypto provider's sha512_digest_init_func with the
 * provider's user_data. The context and the callback are asserted non-null.
 * Returns whatever the callback returns.
 */
int signal_sha512_digest_init(signal_context *context, void **digest_context)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->sha512_digest_init_func);

    return provider->sha512_digest_init_func(digest_context, provider->user_data);
}
498
/*
 * Forward to the crypto provider's sha512_digest_update_func with the
 * provider's user_data. The context and the callback are asserted non-null.
 * Returns whatever the callback returns.
 */
int signal_sha512_digest_update(signal_context *context, void *digest_context, const uint8_t *data, size_t data_len)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->sha512_digest_update_func);

    return provider->sha512_digest_update_func(digest_context, data, data_len, provider->user_data);
}
505
/*
 * Forward to the crypto provider's sha512_digest_final_func with the
 * provider's user_data. The context and the callback are asserted non-null.
 * Returns whatever the callback returns.
 */
int signal_sha512_digest_final(signal_context *context, void *digest_context, signal_buffer **output)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->sha512_digest_final_func);

    return provider->sha512_digest_final_func(digest_context, output, provider->user_data);
}
512
/*
 * Forward to the crypto provider's sha512_digest_cleanup_func with the
 * provider's user_data. The context and the callback are asserted non-null.
 *
 * The callback is invoked as a plain statement: ISO C does not allow
 * `return <expression>;` in a function returning void. This also matches
 * signal_hmac_sha256_cleanup().
 */
void signal_sha512_digest_cleanup(signal_context *context, void *digest_context)
{
    assert(context);
    assert(context->crypto_provider.sha512_digest_cleanup_func);
    context->crypto_provider.sha512_digest_cleanup_func(digest_context, context->crypto_provider.user_data);
}
519
/*
 * Forward to the crypto provider's encrypt_func, passing every argument
 * through unchanged plus the provider's user_data. The context and the
 * callback are asserted non-null. Returns whatever the callback returns.
 */
int signal_encrypt(signal_context *context,
        signal_buffer **output,
        int cipher,
        const uint8_t *key, size_t key_len,
        const uint8_t *iv, size_t iv_len,
        const uint8_t *plaintext, size_t plaintext_len)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->encrypt_func);

    return provider->encrypt_func(
            output, cipher,
            key, key_len,
            iv, iv_len,
            plaintext, plaintext_len,
            provider->user_data);
}
534
/*
 * Forward to the crypto provider's decrypt_func, passing every argument
 * through unchanged plus the provider's user_data. The context and the
 * callback are asserted non-null. Returns whatever the callback returns.
 */
int signal_decrypt(signal_context *context,
        signal_buffer **output,
        int cipher,
        const uint8_t *key, size_t key_len,
        const uint8_t *iv, size_t iv_len,
        const uint8_t *ciphertext, size_t ciphertext_len)
{
    const signal_crypto_provider *provider;

    assert(context);
    provider = &context->crypto_provider;
    assert(provider->decrypt_func);

    return provider->decrypt_func(
            output, cipher,
            key, key_len,
            iv, iv_len,
            ciphertext, ciphertext_len,
            provider->user_data);
}
549
/*
 * Invoke the context's lock callback with the context's user_data.
 * Does nothing when no lock callback is installed.
 */
void signal_lock(signal_context *context)
{
    if(!context->lock) {
        return;
    }
    context->lock(context->user_data);
}
556
/*
 * Invoke the context's unlock callback with the context's user_data.
 * Does nothing when no unlock callback is installed.
 */
void signal_unlock(signal_context *context)
{
    if(!context->unlock) {
        return;
    }
    context->unlock(context->user_data);
}
563
/*
 * Format a printf-style message into a 256-byte stack buffer and pass it to
 * the context's log callback together with its length and the context's
 * user_data.
 *
 * Nothing happens when context is null, when no log callback is installed,
 * or when vsnprintf() reports zero or a negative value. Longer messages are
 * cut to fit the buffer.
 */
void signal_log(signal_context *context, int level, const char *format, ...)
{
    char message[256];
    int written;
    va_list args;

    if(!context || !context->log) {
        return;
    }

    va_start(args, format);
    written = vsnprintf(message, sizeof(message), format, args);
    va_end(args);

    if(written <= 0) {
        return;
    }

    context->log(level, message, strlen(message), context->user_data);
}
578
/*
 * Overwrite n bytes starting at v with zeros.
 *
 * Uses SecureZeroMemory or memset_s when the build defines the matching
 * HAVE_* macro. Otherwise the bytes are written one by one through a
 * volatile-qualified pointer.
 */
void signal_explicit_bzero(void *v, size_t n)
{
#ifdef HAVE_SECUREZEROMEMORY
    SecureZeroMemory(v, n);
#elif HAVE_MEMSET_S
    memset_s(v, n, 0, n);
#else
    volatile unsigned char *bytes = v;
    size_t i;

    for(i = 0; i < n; i++) {
        bytes[i] = 0;
    }
#endif
}
590
/*
 * Compare n bytes of s1 and s2 without stopping at the first difference:
 * every byte pair is visited, so the loop count depends only on n.
 *
 * Returns 0 when all n bytes match. Otherwise returns the bitwise OR of the
 * XOR of each byte pair (1..255). The value carries no ordering, unlike
 * memcmp().
 */
int signal_constant_memcmp(const void *s1, const void *s2, size_t n)
{
    const unsigned char *lhs = (const unsigned char *) s1;
    const unsigned char *rhs = (const unsigned char *) s2;
    unsigned char diff = 0;
    size_t pos = 0;

    while(pos < n) {
        diff |= lhs[pos] ^ rhs[pos];
        pos++;
    }

    return diff;
}
604
/*
 * Point a protobuf binary field at an existing C string.
 *
 * No copy is made: buffer->data aliases str, so str must stay valid while
 * the field is in use. The length excludes the terminating NUL.
 */
void signal_protocol_str_serialize_protobuf(ProtobufCBinaryData *buffer, const char *str)
{
    assert(buffer);
    assert(str);

    buffer->len = strlen(str);
    buffer->data = (uint8_t *)str;
}
612
/*
 * Copy a protobuf binary field into a newly malloc'ed, NUL-terminated
 * string. Returns the string, or 0 when the allocation fails.
 */
char *signal_protocol_str_deserialize_protobuf(ProtobufCBinaryData *buffer)
{
    char *copy;

    assert(buffer);

    copy = malloc(buffer->len + 1);
    if(copy) {
        memcpy(copy, buffer->data, buffer->len);
        copy[buffer->len] = '\0';
    }

    return copy;
}
628
629 /*------------------------------------------------------------------------*/
630
/*
 * Allocate a zero-filled store context bound to global_context.
 *
 * context         receives the new store context; set to null if
 *                 allocation fails
 * global_context  asserted non-null; stored by pointer, not copied
 *
 * Returns 0 on success or SG_ERR_NOMEM.
 */
int signal_protocol_store_context_create(signal_protocol_store_context **context, signal_context *global_context)
{
    signal_protocol_store_context *created;

    assert(global_context);

    created = calloc(1, sizeof(signal_protocol_store_context));
    *context = created;
    if(!created) {
        return SG_ERR_NOMEM;
    }

    created->global_context = global_context;
    return 0;
}
642
/*
 * Copy a session store callback table into the store context.
 * Returns SG_ERR_INVAL when store is null, otherwise 0.
 */
int signal_protocol_store_context_set_session_store(signal_protocol_store_context *context, const signal_protocol_session_store *store)
{
    if(store) {
        memcpy(&(context->session_store), store, sizeof(signal_protocol_session_store));
        return 0;
    }
    return SG_ERR_INVAL;
}
651
/*
 * Copy a pre-key store callback table into the store context.
 * Returns SG_ERR_INVAL when store is null, otherwise 0.
 */
int signal_protocol_store_context_set_pre_key_store(signal_protocol_store_context *context, const signal_protocol_pre_key_store *store)
{
    if(store) {
        memcpy(&(context->pre_key_store), store, sizeof(signal_protocol_pre_key_store));
        return 0;
    }
    return SG_ERR_INVAL;
}
660
/*
 * Copy a signed pre-key store callback table into the store context.
 * Returns SG_ERR_INVAL when store is null, otherwise 0.
 */
int signal_protocol_store_context_set_signed_pre_key_store(signal_protocol_store_context *context, const signal_protocol_signed_pre_key_store *store)
{
    if(store) {
        memcpy(&(context->signed_pre_key_store), store, sizeof(signal_protocol_signed_pre_key_store));
        return 0;
    }
    return SG_ERR_INVAL;
}
669
/*
 * Copy an identity key store callback table into the store context.
 * Returns SG_ERR_INVAL when store is null, otherwise 0.
 */
int signal_protocol_store_context_set_identity_key_store(signal_protocol_store_context *context, const signal_protocol_identity_key_store *store)
{
    if(store) {
        memcpy(&(context->identity_key_store), store, sizeof(signal_protocol_identity_key_store));
        return 0;
    }
    return SG_ERR_INVAL;
}
678
/*
 * Copy a sender key store callback table into the store context.
 * Returns SG_ERR_INVAL when store is null, otherwise 0.
 */
int signal_protocol_store_context_set_sender_key_store(signal_protocol_store_context *context, const signal_protocol_sender_key_store *store)
{
    if(store) {
        memcpy(&(context->sender_key_store), store, sizeof(signal_protocol_sender_key_store));
        return 0;
    }
    return SG_ERR_INVAL;
}
687
/*
 * Tear down a store context; null is ignored.
 *
 * Each of the five stores that has a destroy_func gets it called with that
 * store's user_data, in the order session, pre-key, signed pre-key,
 * identity key, sender key. The context itself is freed last. The global
 * context it points to is not touched.
 */
void signal_protocol_store_context_destroy(signal_protocol_store_context *context)
{
    if(!context) {
        return;
    }

    if(context->session_store.destroy_func) {
        context->session_store.destroy_func(context->session_store.user_data);
    }

    if(context->pre_key_store.destroy_func) {
        context->pre_key_store.destroy_func(context->pre_key_store.user_data);
    }

    if(context->signed_pre_key_store.destroy_func) {
        context->signed_pre_key_store.destroy_func(context->signed_pre_key_store.user_data);
    }

    if(context->identity_key_store.destroy_func) {
        context->identity_key_store.destroy_func(context->identity_key_store.user_data);
    }

    if(context->sender_key_store.destroy_func) {
        context->sender_key_store.destroy_func(context->sender_key_store.user_data);
    }

    free(context);
}
709
710 /*------------------------------------------------------------------------*/
711
/*
 * Load the session record for address through the session store callback.
 *
 * The callback's return value selects the path:
 *   < 0  the error is returned unchanged
 *   0    a new record is built with session_record_create(); a buffer
 *        supplied anyway is rejected with SG_ERR_UNKNOWN
 *   1    the returned buffer is parsed with session_record_deserialize();
 *        a missing buffer yields -1
 *   else SG_ERR_UNKNOWN
 *
 * On a non-negative final result *record receives the record and any user
 * buffer from the callback is attached via session_record_set_user_record().
 * On failure *record is left untouched and the user buffer is freed. The
 * serialized buffer is always freed here.
 */
int signal_protocol_session_load_session(signal_protocol_store_context *context, session_record **record, const signal_protocol_address *address)
{
    int result = 0;
    signal_buffer *serialized = 0;
    signal_buffer *user_buffer = 0;
    session_record *loaded = 0;

    assert(context);
    assert(context->session_store.load_session_func);

    result = context->session_store.load_session_func(
            &serialized, &user_buffer, address,
            context->session_store.user_data);

    if(result >= 0) {
        switch(result) {
        case 0:
            if(serialized) {
                result = SG_ERR_UNKNOWN;
            }
            else {
                result = session_record_create(&loaded, 0, context->global_context);
            }
            break;
        case 1:
            if(!serialized) {
                result = -1;
            }
            else {
                result = session_record_deserialize(&loaded,
                        signal_buffer_data(serialized), signal_buffer_len(serialized), context->global_context);
            }
            break;
        default:
            result = SG_ERR_UNKNOWN;
            break;
        }
    }

    if(serialized) {
        signal_buffer_free(serialized);
    }

    if(result < 0) {
        signal_buffer_free(user_buffer);
        return result;
    }

    if(user_buffer) {
        session_record_set_user_record(loaded, user_buffer);
    }
    *record = loaded;
    return result;
}
763
/*
 * Forward to the session store's get_sub_device_sessions_func with the
 * store's user_data. The context and the callback are asserted non-null.
 * Returns whatever the callback returns.
 */
int signal_protocol_session_get_sub_device_sessions(signal_protocol_store_context *context, signal_int_list **sessions, const char *name, size_t name_len)
{
    const signal_protocol_session_store *store;

    assert(context);
    store = &context->session_store;
    assert(store->get_sub_device_sessions_func);

    return store->get_sub_device_sessions_func(
            sessions, name, name_len,
            store->user_data);
}
773
/*
 * Serialize record and hand it to the session store's store_session_func
 * for address.
 *
 * If the record has a user record attached, its data pointer and length are
 * passed as well; otherwise a null pointer and zero length are passed.
 * Returns the serialization error if that step fails, else the callback's
 * return value. The temporary serialized buffer is freed before returning.
 */
int signal_protocol_session_store_session(signal_protocol_store_context *context, const signal_protocol_address *address, session_record *record)
{
    int result = 0;
    signal_buffer *serialized = 0;
    signal_buffer *user_record = 0;
    uint8_t *user_record_data = 0;
    size_t user_record_len = 0;

    assert(context);
    assert(context->session_store.store_session_func);
    assert(record);

    result = session_record_serialize(&serialized, record);
    if(result >= 0) {
        user_record = session_record_get_user_record(record);
        if(user_record) {
            user_record_data = signal_buffer_data(user_record);
            user_record_len = signal_buffer_len(user_record);
        }

        result = context->session_store.store_session_func(
                address,
                signal_buffer_data(serialized), signal_buffer_len(serialized),
                user_record_data, user_record_len,
                context->session_store.user_data);
    }

    if(serialized) {
        signal_buffer_free(serialized);
    }

    return result;
}
810
/*
 * Forward to the session store's contains_session_func with the store's
 * user_data. The context and the callback are asserted non-null.
 * Returns whatever the callback returns.
 */
int signal_protocol_session_contains_session(signal_protocol_store_context *context, const signal_protocol_address *address)
{
    const signal_protocol_session_store *store;

    assert(context);
    store = &context->session_store;
    assert(store->contains_session_func);

    return store->contains_session_func(address, store->user_data);
}
820
/*
 * Forward to the session store's delete_session_func with the store's
 * user_data. The context and the callback are asserted non-null.
 * Returns whatever the callback returns.
 */
int signal_protocol_session_delete_session(signal_protocol_store_context *context, const signal_protocol_address *address)
{
    const signal_protocol_session_store *store;

    assert(context);
    store = &context->session_store;
    assert(store->delete_session_func);

    return store->delete_session_func(address, store->user_data);
}
830
/*
 * Forward to the session store's delete_all_sessions_func with the store's
 * user_data. The context and the callback are asserted non-null.
 * Returns whatever the callback returns.
 */
int signal_protocol_session_delete_all_sessions(signal_protocol_store_context *context, const char *name, size_t name_len)
{
    const signal_protocol_session_store *store;

    assert(context);
    store = &context->session_store;
    assert(store->delete_all_sessions_func);

    return store->delete_all_sessions_func(name, name_len, store->user_data);
}
840
841 /*------------------------------------------------------------------------*/
842
/*
 * Load the pre-key identified by pre_key_id through the pre-key store's
 * load_pre_key callback and deserialize it.
 *
 * Returns the callback's error if it is negative, SG_ERR_UNKNOWN if the
 * callback reports success without supplying a buffer, and otherwise the
 * result of session_pre_key_deserialize(). *pre_key is written only when
 * the final result is non-negative. The serialized buffer is freed here.
 */
int signal_protocol_pre_key_load_key(signal_protocol_store_context *context, session_pre_key **pre_key, uint32_t pre_key_id)
{
    int result = 0;
    signal_buffer *buffer = 0;
    session_pre_key *result_key = 0;

    assert(context);
    assert(context->pre_key_store.load_pre_key);

    result = context->pre_key_store.load_pre_key(
            &buffer, pre_key_id,
            context->pre_key_store.user_data);
    if(result < 0) {
        goto complete;
    }

    /* Guard against a callback that reports success but returns no data;
     * signal_buffer_data()/signal_buffer_len() dereference their argument. */
    if(!buffer) {
        result = SG_ERR_UNKNOWN;
        goto complete;
    }

    result = session_pre_key_deserialize(&result_key,
            signal_buffer_data(buffer), signal_buffer_len(buffer), context->global_context);

complete:
    if(buffer) {
        signal_buffer_free(buffer);
    }
    if(result >= 0) {
        *pre_key = result_key;
    }
    return result;
}
871
/*
 * Serialize pre_key and pass it, keyed by its own ID, to the pre-key
 * store's store_pre_key callback.
 *
 * Returns the serialization error if that step fails, else the callback's
 * return value. The temporary serialized buffer is freed before returning.
 */
int signal_protocol_pre_key_store_key(signal_protocol_store_context *context, session_pre_key *pre_key)
{
    int result = 0;
    signal_buffer *serialized = 0;
    uint32_t key_id = 0;

    assert(context);
    assert(context->pre_key_store.store_pre_key);
    assert(pre_key);

    key_id = session_pre_key_get_id(pre_key);

    result = session_pre_key_serialize(&serialized, pre_key);
    if(result >= 0) {
        result = context->pre_key_store.store_pre_key(
                key_id,
                signal_buffer_data(serialized), signal_buffer_len(serialized),
                context->pre_key_store.user_data);
    }

    if(serialized) {
        signal_buffer_free(serialized);
    }

    return result;
}
901
/*
 * Forward to the pre-key store's contains_pre_key callback with the store's
 * user_data. The context and the callback are asserted non-null.
 * Returns whatever the callback returns.
 */
int signal_protocol_pre_key_contains_key(signal_protocol_store_context *context, uint32_t pre_key_id)
{
    const signal_protocol_pre_key_store *store;

    assert(context);
    store = &context->pre_key_store;
    assert(store->contains_pre_key);

    return store->contains_pre_key(pre_key_id, store->user_data);
}
914
/*
 * Forward to the pre-key store's remove_pre_key callback with the store's
 * user_data. The context and the callback are asserted non-null.
 * Returns whatever the callback returns.
 */
int signal_protocol_pre_key_remove_key(signal_protocol_store_context *context, uint32_t pre_key_id)
{
    const signal_protocol_pre_key_store *store;

    assert(context);
    store = &context->pre_key_store;
    assert(store->remove_pre_key);

    return store->remove_pre_key(pre_key_id, store->user_data);
}
927
928 /*------------------------------------------------------------------------*/
929
/*
 * Load the signed pre-key identified by signed_pre_key_id through the
 * signed pre-key store's load_signed_pre_key callback and deserialize it.
 *
 * Returns the callback's error if it is negative, SG_ERR_UNKNOWN if the
 * callback reports success without supplying a buffer, and otherwise the
 * result of session_signed_pre_key_deserialize(). *pre_key is written only
 * when the final result is non-negative. The serialized buffer is freed
 * here.
 */
int signal_protocol_signed_pre_key_load_key(signal_protocol_store_context *context, session_signed_pre_key **pre_key, uint32_t signed_pre_key_id)
{
    int result = 0;
    signal_buffer *buffer = 0;
    session_signed_pre_key *result_key = 0;

    assert(context);
    assert(context->signed_pre_key_store.load_signed_pre_key);

    result = context->signed_pre_key_store.load_signed_pre_key(
            &buffer, signed_pre_key_id,
            context->signed_pre_key_store.user_data);
    if(result < 0) {
        goto complete;
    }

    /* Guard against a callback that reports success but returns no data;
     * signal_buffer_data()/signal_buffer_len() dereference their argument. */
    if(!buffer) {
        result = SG_ERR_UNKNOWN;
        goto complete;
    }

    result = session_signed_pre_key_deserialize(&result_key,
            signal_buffer_data(buffer), signal_buffer_len(buffer), context->global_context);

complete:
    if(buffer) {
        signal_buffer_free(buffer);
    }
    if(result >= 0) {
        *pre_key = result_key;
    }
    return result;
}
958
/*
 * Serialize pre_key and pass it, keyed by its own ID, to the signed pre-key
 * store's store_signed_pre_key callback.
 *
 * Returns the serialization error if that step fails, else the callback's
 * return value. The temporary serialized buffer is freed before returning.
 */
int signal_protocol_signed_pre_key_store_key(signal_protocol_store_context *context, session_signed_pre_key *pre_key)
{
    int result = 0;
    signal_buffer *serialized = 0;
    uint32_t key_id = 0;

    assert(context);
    assert(context->signed_pre_key_store.store_signed_pre_key);
    assert(pre_key);

    key_id = session_signed_pre_key_get_id(pre_key);

    result = session_signed_pre_key_serialize(&serialized, pre_key);
    if(result >= 0) {
        result = context->signed_pre_key_store.store_signed_pre_key(
                key_id,
                signal_buffer_data(serialized), signal_buffer_len(serialized),
                context->signed_pre_key_store.user_data);
    }

    if(serialized) {
        signal_buffer_free(serialized);
    }

    return result;
}
988
/*
 * Ask the signed pre key store whether it holds a key with the given ID.
 *
 * The context and its contains_signed_pre_key callback must both be set
 * (checked with assert()). The callback's return value is passed back to
 * the caller unchanged.
 */
int signal_protocol_signed_pre_key_contains_key(signal_protocol_store_context *context, uint32_t signed_pre_key_id)
{
    assert(context);
    assert(context->signed_pre_key_store.contains_signed_pre_key);

    return context->signed_pre_key_store.contains_signed_pre_key(
            signed_pre_key_id, context->signed_pre_key_store.user_data);
}
1001
/*
 * Remove the signed pre key with the given ID via the signed pre key store.
 *
 * The context and its remove_signed_pre_key callback must both be set
 * (checked with assert()). The callback's return value is passed back to
 * the caller unchanged.
 */
int signal_protocol_signed_pre_key_remove_key(signal_protocol_store_context *context, uint32_t signed_pre_key_id)
{
    assert(context);
    assert(context->signed_pre_key_store.remove_signed_pre_key);

    return context->signed_pre_key_store.remove_signed_pre_key(
            signed_pre_key_id, context->signed_pre_key_store.user_data);
}
1014
1015 /*------------------------------------------------------------------------*/
1016
/*
 * Fetch the local identity key pair from the identity key store.
 *
 * Steps, each performed only if the previous one did not fail:
 *   1. get_identity_key_pair callback fills a public and a private buffer
 *   2. the public buffer is decoded with curve_decode_point()
 *   3. the private buffer is decoded with curve_decode_private_point()
 *   4. both keys are combined with ratchet_identity_key_pair_create()
 *
 * Both buffers and the locally held key references are released before
 * returning, on success and on failure alike.
 *
 * Returns the first negative result encountered, otherwise the result of
 * the final step. *key_pair is written only when the returned value is
 * >= 0.
 *
 * NOTE(review): a non-negative callback result is assumed to mean both
 * buffers were filled in; they are dereferenced without a null check.
 * NOTE(review): the created pair presumably holds its own references to
 * the two keys, since the local ones are dropped here. Confirm against
 * ratchet_identity_key_pair_create().
 */
int signal_protocol_identity_get_key_pair(signal_protocol_store_context *context, ratchet_identity_key_pair **key_pair)
{
    signal_buffer *pub_data = 0;
    signal_buffer *priv_data = 0;
    ec_public_key *pub = 0;
    ec_private_key *priv = 0;
    ratchet_identity_key_pair *pair = 0;
    int status;

    assert(context);
    assert(context->identity_key_store.get_identity_key_pair);

    status = context->identity_key_store.get_identity_key_pair(
            &pub_data, &priv_data,
            context->identity_key_store.user_data);

    if(status >= 0) {
        status = curve_decode_point(&pub, pub_data->data, pub_data->len, context->global_context);
    }

    if(status >= 0) {
        status = curve_decode_private_point(&priv, priv_data->data, priv_data->len, context->global_context);
    }

    if(status >= 0) {
        status = ratchet_identity_key_pair_create(&pair, pub, priv);
    }

    /* Release everything acquired above, regardless of outcome. */
    if(pub_data) {
        signal_buffer_free(pub_data);
    }
    if(priv_data) {
        signal_buffer_free(priv_data);
    }
    if(pub) {
        SIGNAL_UNREF(pub);
    }
    if(priv) {
        SIGNAL_UNREF(priv);
    }

    if(status < 0) {
        return status;
    }

    *key_pair = pair;
    return status;
}
1069
/*
 * Retrieve the local registration ID from the identity key store.
 *
 * The context and its get_local_registration_id callback must both be set
 * (checked with assert()). Note the callback's argument order: the store's
 * user_data comes first, followed by the registration_id output pointer,
 * which is forwarded as given. The callback's return value is passed back
 * to the caller unchanged.
 */
int signal_protocol_identity_get_local_registration_id(signal_protocol_store_context *context, uint32_t *registration_id)
{
    assert(context);
    assert(context->identity_key_store.get_local_registration_id);

    return context->identity_key_store.get_local_registration_id(
            context->identity_key_store.user_data, registration_id);
}
1082
/*
 * Pass an identity key for the given address to the identity key store.
 *
 * When identity_key is null, the save_identity callback is invoked with a
 * null data pointer and a length of 0. Otherwise the key is serialized
 * with ec_public_key_serialize() and the serialized bytes are passed to
 * the callback.
 *
 * Returns the negative serialization result if that step fails, otherwise
 * whatever the callback returns. The temporary serialization buffer is
 * freed before returning.
 */
int signal_protocol_identity_save_identity(signal_protocol_store_context *context, const signal_protocol_address *address, ec_public_key *identity_key)
{
    signal_buffer *serialized = 0;
    int status;

    assert(context);
    assert(context->identity_key_store.save_identity);

    /* No key supplied: forward an empty payload, nothing to serialize. */
    if(!identity_key) {
        return context->identity_key_store.save_identity(
                address, 0, 0,
                context->identity_key_store.user_data);
    }

    status = ec_public_key_serialize(&serialized, identity_key);
    if(status >= 0) {
        status = context->identity_key_store.save_identity(
                address,
                signal_buffer_data(serialized),
                signal_buffer_len(serialized),
                context->identity_key_store.user_data);
    }

    if(serialized) {
        signal_buffer_free(serialized);
    }

    return status;
}
1116
/*
 * Ask the identity key store whether identity_key is trusted for address.
 *
 * The key is serialized with ec_public_key_serialize() and the serialized
 * bytes are passed to the is_trusted_identity callback together with the
 * address and the store's user_data.
 *
 * Returns the negative serialization result if that step fails, otherwise
 * whatever the callback returns. The temporary serialization buffer is
 * freed before returning.
 */
int signal_protocol_identity_is_trusted_identity(signal_protocol_store_context *context, const signal_protocol_address *address, ec_public_key *identity_key)
{
    signal_buffer *serialized = 0;
    int status;

    assert(context);
    assert(context->identity_key_store.is_trusted_identity);

    status = ec_public_key_serialize(&serialized, identity_key);
    if(status >= 0) {
        status = context->identity_key_store.is_trusted_identity(
                address,
                signal_buffer_data(serialized),
                signal_buffer_len(serialized),
                context->identity_key_store.user_data);
    }

    if(serialized) {
        signal_buffer_free(serialized);
    }

    return status;
}
1142
/*
 * Serialize a sender key record and pass it to the sender key store.
 *
 * The record is serialized with sender_key_record_serialize(). If the
 * record has a user record attached (sender_key_record_get_user_record()
 * returns non-null), its data pointer and length are forwarded as well;
 * otherwise a null pointer and a length of 0 are passed in their place.
 *
 * Returns the negative serialization result if that step fails, otherwise
 * whatever the store_sender_key callback returns. Only the temporary
 * serialization buffer is freed here; the user record buffer obtained from
 * the record is not.
 */
int signal_protocol_sender_key_store_key(signal_protocol_store_context *context, const signal_protocol_sender_key_name *sender_key_name, sender_key_record *record)
{
    signal_buffer *serialized = 0;
    signal_buffer *user_record = 0;
    uint8_t *user_record_data = 0;
    size_t user_record_len = 0;
    int status;

    assert(context);
    assert(context->sender_key_store.store_sender_key);
    assert(record);

    status = sender_key_record_serialize(&serialized, record);
    if(status >= 0) {
        user_record = sender_key_record_get_user_record(record);
        if(user_record) {
            user_record_data = signal_buffer_data(user_record);
            user_record_len = signal_buffer_len(user_record);
        }

        status = context->sender_key_store.store_sender_key(
                sender_key_name,
                signal_buffer_data(serialized), signal_buffer_len(serialized),
                user_record_data, user_record_len,
                context->sender_key_store.user_data);
    }

    if(serialized) {
        signal_buffer_free(serialized);
    }

    return status;
}
1179
/*
 * Load the sender key record stored under sender_key_name.
 *
 * The load_sender_key store callback's return value is interpreted as:
 *    1  - a record exists; the returned buffer is decoded with
 *         sender_key_record_deserialize()
 *    0  - no record exists; a new one is made with
 *         sender_key_record_create()
 *   <0  - failure, returned to the caller unchanged
 * Any other value, or a return code that disagrees with whether a buffer
 * was actually produced, is reported as SG_ERR_UNKNOWN.
 *
 * The record buffer is always freed here. The optional user buffer is
 * attached to the resulting record via sender_key_record_set_user_record()
 * on success (presumably transferring ownership to the record -- confirm)
 * and freed here on failure.
 *
 * *record is written only when the returned value is >= 0.
 */
int signal_protocol_sender_key_load_key(signal_protocol_store_context *context, sender_key_record **record, const signal_protocol_sender_key_name *sender_key_name)
{
    int result = 0;
    signal_buffer *buffer = 0;
    signal_buffer *user_buffer = 0;
    sender_key_record *result_record = 0;

    assert(context);
    assert(context->sender_key_store.load_sender_key);

    result = context->sender_key_store.load_sender_key(
            &buffer, &user_buffer, sender_key_name,
            context->sender_key_store.user_data);
    if(result < 0) {
        goto complete;
    }

    if(result == 0) {
        /* "Not found" must not come with record data. */
        if(buffer) {
            result = SG_ERR_UNKNOWN;
            goto complete;
        }
        result = sender_key_record_create(&result_record, context->global_context);
    }
    else if(result == 1) {
        /* "Found" must come with record data. Report the mismatch with
         * the same named error code as the other inconsistency branches
         * instead of a bare -1. */
        if(!buffer) {
            result = SG_ERR_UNKNOWN;
            goto complete;
        }
        result = sender_key_record_deserialize(&result_record,
                signal_buffer_data(buffer), signal_buffer_len(buffer), context->global_context);
    }
    else {
        result = SG_ERR_UNKNOWN;
    }

complete:
    if(buffer) {
        signal_buffer_free(buffer);
    }
    if(result >= 0) {
        if(user_buffer) {
            sender_key_record_set_user_record(result_record, user_buffer);
        }
        *record = result_record;
    }
    else if(user_buffer) {
        /* Guarded like every other signal_buffer_free() call in this file,
         * so a null user buffer is never passed to it. */
        signal_buffer_free(user_buffer);
    }
    return result;
}
1231