1 #include "session_record.h"
2 
3 #include <stdlib.h>
4 #include <string.h>
5 #include <assert.h>
6 
7 #include "session_state.h"
8 #include "utlist.h"
9 #include "LocalStorageProtocol.pb-c.h"
10 #include "signal_protocol_internal.h"
11 
/* Maximum number of archived (previous) session states kept in a record.
 * session_record_promote_state() discards any archived states beyond this. */
#define ARCHIVED_STATES_MAX_LENGTH 40

/* Node of the doubly linked list of archived session states (manipulated with
 * the utlist DL_* macros). The node holds a reference on `state`, which is
 * released whenever the node is removed from the list. */
struct session_record_state_node
{
    session_state *state;
    struct session_record_state_node *prev, *next;
};
19 
/* A session record: the current session state plus a bounded list of
 * archived states. */
struct session_record
{
    /* Reference-counting header initialised by SIGNAL_INIT(); kept as the
     * first member because session_record_destroy() casts the base pointer
     * back to session_record. */
    signal_type_base base;
    /* Current session state; the record holds a reference on it. */
    session_state *state;
    /* Archived states, most recently archived first (DL_PREPEND in
     * session_record_promote_state()), capped at ARCHIVED_STATES_MAX_LENGTH. */
    session_record_state_node *previous_states_head;
    /* 1 when session_record_create() had to create a new empty state,
     * 0 when a state was supplied or the record was deserialized. */
    int is_fresh;
    /* Opaque caller-supplied buffer owned by the record (freed on replace
     * and in session_record_destroy()). */
    signal_buffer *user_record;
    /* Context used when the record needs to create new session states. */
    signal_context *global_context;
};

static void session_record_free_previous_states(session_record *record);
31 
/**
 * Allocate a new session record.
 *
 * @param record         receives the new record on success
 * @param state          existing state to adopt (a reference is taken), or
 *                       NULL to create a brand-new empty state
 * @param global_context context stored in the record and used to create the
 *                       empty state when `state` is NULL
 * @return 0 on success, SG_ERR_NOMEM if the record cannot be allocated, or the
 *         negative error from session_state_create()
 */
int session_record_create(session_record **record, session_state *state, signal_context *global_context)
{
    session_record *rec = calloc(1, sizeof(session_record));

    if(!rec) {
        return SG_ERR_NOMEM;
    }
    SIGNAL_INIT(rec, session_record_destroy);

    if(state) {
        /* Adopt the caller's state: the record keeps its own reference. */
        SIGNAL_REF(state);
        rec->state = state;
        rec->is_fresh = 0;
    }
    else {
        int err = session_state_create(&rec->state, global_context);
        if(err < 0) {
            SIGNAL_UNREF(rec);
            return err;
        }
        /* Nothing has been negotiated yet for a newly created state. */
        rec->is_fresh = 1;
    }

    rec->global_context = global_context;
    *record = rec;
    return 0;
}
59 
/**
 * Serialize a session record into a newly allocated protobuf buffer.
 *
 * The current state (if any) is encoded as the current session, and each
 * archived state is encoded, in list order, as a previous session.
 *
 * @param buffer receives the encoded buffer on success; the caller owns it and
 *               releases it with signal_buffer_free(). Not written on failure.
 * @param record record to serialize
 * @return 0 on success; SG_ERR_INVAL if record is NULL; SG_ERR_NOMEM on
 *         allocation failure; SG_ERR_INVALID_PROTO_BUF if the packed size does
 *         not match the computed size; otherwise the negative error returned by
 *         session_state_serialize_prepare().
 */
int session_record_serialize(signal_buffer **buffer, const session_record *record)
{
    int result = 0;
    size_t result_size = 0;
    size_t i = 0;
    Textsecure__RecordStructure record_structure = TEXTSECURE__RECORD_STRUCTURE__INIT;
    session_record_state_node *cur_node = 0;
    signal_buffer *result_buf = 0;
    size_t len = 0;
    uint8_t *data = 0;

    if(!record) {
        result = SG_ERR_INVAL;
        goto complete;
    }

    if(record->state) {
        record_structure.currentsession = malloc(sizeof(Textsecure__SessionStructure));
        if(!record_structure.currentsession) {
            result = SG_ERR_NOMEM;
            goto complete;
        }
        textsecure__session_structure__init(record_structure.currentsession);
        result = session_state_serialize_prepare(record->state, record_structure.currentsession);
        if(result < 0) {
            goto complete;
        }
    }

    if(record->previous_states_head) {
        size_t count;
        DL_COUNT(record->previous_states_head, cur_node, count);

        // Guard the pointer-array size computation against overflow
        if(count > SIZE_MAX / sizeof(Textsecure__SessionStructure *)) {
            result = SG_ERR_NOMEM;
            goto complete;
        }

        record_structure.previoussessions = malloc(sizeof(Textsecure__SessionStructure *) * count);
        if(!record_structure.previoussessions) {
            result = SG_ERR_NOMEM;
            goto complete;
        }

        i = 0;
        DL_FOREACH(record->previous_states_head, cur_node) {
            Textsecure__SessionStructure *session_structure = malloc(sizeof(Textsecure__SessionStructure));
            if(!session_structure) {
                result = SG_ERR_NOMEM;
                break;
            }
            textsecure__session_structure__init(session_structure);

            // Count the entry BEFORE preparing it, so that a structure whose
            // preparation fails part-way is still released by the cleanup
            // code below (the same way the current session is handled).
            record_structure.previoussessions[i] = session_structure;
            i++;

            result = session_state_serialize_prepare(cur_node->state, session_structure);
            if(result < 0) {
                break;
            }
        }
        record_structure.n_previoussessions = i;
        if(result < 0) {
            goto complete;
        }
    }

    len = textsecure__record_structure__get_packed_size(&record_structure);

    result_buf = signal_buffer_alloc(len);
    if(!result_buf) {
        result = SG_ERR_NOMEM;
        goto complete;
    }

    data = signal_buffer_data(result_buf);
    result_size = textsecure__record_structure__pack(&record_structure, data);
    if(result_size != len) {
        signal_buffer_free(result_buf);
        result = SG_ERR_INVALID_PROTO_BUF;
        result_buf = 0;
        goto complete;
    }

complete:
    if(record_structure.currentsession) {
        session_state_serialize_prepare_free(record_structure.currentsession);
    }
    if(record_structure.previoussessions) {
        for(i = 0; i < record_structure.n_previoussessions; i++) {
            if(record_structure.previoussessions[i]) {
                session_state_serialize_prepare_free(record_structure.previoussessions[i]);
            }
        }
        free(record_structure.previoussessions);
    }

    if(result >= 0) {
        *buffer = result_buf;
    }
    return result;
}
159 
/**
 * Parse a protobuf-encoded session record.
 *
 * @param record         receives the new record on success (not written on
 *                       failure)
 * @param data           encoded bytes
 * @param len            number of bytes in `data`
 * @param global_context context passed to the deserialized states and record
 * @return 0 on success; SG_ERR_INVALID_PROTO_BUF if the bytes cannot be
 *         unpacked; SG_ERR_NOMEM on allocation failure; otherwise the negative
 *         error from session_state_deserialize_protobuf() or
 *         session_record_create().
 */
int session_record_deserialize(session_record **record, const uint8_t *data, size_t len, signal_context *global_context)
{
    int result = 0;
    size_t i;
    session_record *result_record = 0;
    session_state *current_state = 0;
    session_record_state_node *previous_states_head = 0;
    Textsecure__RecordStructure *record_structure = 0;

    record_structure = textsecure__record_structure__unpack(0, len, data);
    if(!record_structure) {
        result = SG_ERR_INVALID_PROTO_BUF;
        goto complete;
    }

    if(record_structure->currentsession) {
        result = session_state_deserialize_protobuf(&current_state, record_structure->currentsession, global_context);
        if(result < 0) {
            goto complete;
        }
    }

    // session_record_create() takes its own reference on current_state (or
    // creates an empty state when it is NULL), so drop ours afterwards.
    result = session_record_create(&result_record, current_state, global_context);
    if(result < 0) {
        goto complete;
    }
    SIGNAL_UNREF(current_state);
    current_state = 0;
    // A deserialized record is never reported as fresh, even when it had no
    // current session and session_record_create() made an empty one.
    result_record->is_fresh = 0;

    for(i = 0; i < record_structure->n_previoussessions; i++) {
        session_record_state_node *node = malloc(sizeof(session_record_state_node));
        if(!node) {
            result = SG_ERR_NOMEM;
            goto complete;
        }

        result = session_state_deserialize_protobuf(&node->state, record_structure->previoussessions[i], global_context);
        if(result < 0) {
            free(node);
            goto complete;
        }

        DL_APPEND(previous_states_head, node);
    }
    // Hand the whole list to the record; clear the local so cleanup skips it.
    result_record->previous_states_head = previous_states_head;
    previous_states_head = 0;

complete:
    if(record_structure) {
        textsecure__record_structure__free_unpacked(record_structure, 0);
    }
    if(current_state) {
        SIGNAL_UNREF(current_state);
    }
    if(previous_states_head) {
        session_record_state_node *cur_node;
        session_record_state_node *tmp_node;
        DL_FOREACH_SAFE(previous_states_head, cur_node, tmp_node) {
            DL_DELETE(previous_states_head, cur_node);
            // Each node owns the state deserialized into it; release it too,
            // otherwise every already-parsed archived state leaks on error.
            if(cur_node->state) {
                SIGNAL_UNREF(cur_node->state);
            }
            free(cur_node);
        }
    }
    if(result_record) {
        if(result < 0) {
            SIGNAL_UNREF(result_record);
        }
        else {
            *record = result_record;
        }
    }

    return result;
}
239 
/**
 * Deep-copy a session record by round-tripping it through its serialized
 * form, then duplicating the user record buffer (if any).
 *
 * @param record         receives the copy on success (not written on failure)
 * @param other_record   record to copy; must not be NULL
 * @param global_context context for the copy; must not be NULL
 * @return 0 on success, SG_ERR_NOMEM if the user record cannot be copied, or
 *         the negative error from serialization/deserialization
 */
int session_record_copy(session_record **record, session_record *other_record, signal_context *global_context)
{
    int rc;
    session_record *copy = 0;
    signal_buffer *serialized = 0;

    assert(other_record);
    assert(global_context);

    rc = session_record_serialize(&serialized, other_record);
    if(rc >= 0) {
        uint8_t *bytes = signal_buffer_data(serialized);
        size_t byte_count = signal_buffer_len(serialized);
        rc = session_record_deserialize(&copy, bytes, byte_count, global_context);
    }

    /* The user record is not part of the serialized form; copy it separately. */
    if(rc >= 0 && other_record->user_record) {
        copy->user_record = signal_buffer_copy(other_record->user_record);
        if(!copy->user_record) {
            rc = SG_ERR_NOMEM;
        }
    }

    if(serialized) {
        signal_buffer_free(serialized);
    }

    if(rc < 0) {
        SIGNAL_UNREF(copy);
    }
    else {
        *record = copy;
    }
    return rc;
}
283 
/**
 * Check whether the record holds a state with the given session version and
 * Alice base key, looking first at the current state and then at each
 * archived state in list order.
 *
 * @param record         record to search; it and its current state must be
 *                       non-NULL
 * @param version        session version to match
 * @param alice_base_key base key to match (via ec_public_key_compare() == 0)
 * @return 1 if a matching state exists, 0 otherwise
 */
int session_record_has_session_state(session_record *record, uint32_t version, const ec_public_key *alice_base_key)
{
    session_state *candidate;
    session_record_state_node *pending;

    assert(record);
    assert(record->state);

    candidate = record->state;
    pending = record->previous_states_head;

    for(;;) {
        if(session_state_get_session_version(candidate) == version &&
           ec_public_key_compare(session_state_get_alice_base_key(candidate), alice_base_key) == 0) {
            return 1;
        }
        if(!pending) {
            return 0;
        }
        /* Advance to the next archived state. */
        candidate = pending->state;
        pending = pending->next;
    }
}
309 
session_record_get_state(session_record * record)310 session_state *session_record_get_state(session_record *record)
311 {
312     return record->state;
313 }
314 
/**
 * Replace the record's current session state.
 *
 * A reference on `state` is taken and the reference on the previous state
 * (if any) is released.
 *
 * @param record record to modify; must not be NULL
 * @param state  new current state; must not be NULL. May be the record's
 *               current state, in which case this is a no-op in effect.
 */
void session_record_set_state(session_record *record, session_state *state)
{
    assert(record);
    assert(state);

    // Take the new reference BEFORE dropping the old one: when `state` is
    // already the current state, unreferencing first could destroy it and
    // the following ref would touch freed memory.
    SIGNAL_REF(state);
    if(record->state) {
        SIGNAL_UNREF(record->state);
    }
    record->state = state;
}
325 
session_record_get_previous_states_head(const session_record * record)326 session_record_state_node *session_record_get_previous_states_head(const session_record *record)
327 {
328     assert(record);
329     return record->previous_states_head;
330 }
331 
session_record_get_previous_states_element(const session_record_state_node * node)332 session_state *session_record_get_previous_states_element(const session_record_state_node *node)
333 {
334     assert(node);
335     return node->state;
336 }
337 
session_record_get_previous_states_next(const session_record_state_node * node)338 session_record_state_node *session_record_get_previous_states_next(const session_record_state_node *node)
339 {
340     assert(node);
341     return node->next;
342 }
343 
/**
 * Unlink `node` from the record's archived-state list, release its state
 * reference and free the node.
 *
 * @param record record owning the list; must not be NULL
 * @param node   node to remove; must not be NULL and must belong to `record`
 * @return the node that followed `node` (NULL if it was the last), so callers
 *         can keep iterating while removing
 */
session_record_state_node *session_record_get_previous_states_remove(session_record *record, session_record_state_node *node)
{
    session_record_state_node *following;

    assert(record);
    assert(node);

    /* Capture the successor before the node is unlinked and freed. */
    following = node->next;

    DL_DELETE(record->previous_states_head, node);
    SIGNAL_UNREF(node->state);
    free(node);

    return following;
}
357 
/**
 * Return the record's `is_fresh` flag: 1 when session_record_create() had to
 * create a new empty state, 0 otherwise. `record` must not be NULL.
 */
int session_record_is_fresh(session_record *record)
{
    int fresh;

    assert(record);
    fresh = record->is_fresh;
    return fresh;
}
363 
/**
 * Archive the current state and replace it with a newly created empty state.
 *
 * @param record record to modify; must not be NULL
 * @return 0 on success, or the negative error from session_state_create() or
 *         session_record_promote_state()
 */
int session_record_archive_current_state(session_record *record)
{
    int rc;
    session_state *blank_state = 0;

    assert(record);

    rc = session_state_create(&blank_state, record->global_context);
    if(rc >= 0) {
        rc = session_record_promote_state(record, blank_state);
    }

    /* promote_state takes its own reference; drop the creation reference. */
    SIGNAL_UNREF(blank_state);
    return rc;
}
382 
/**
 * Make `promoted_state` the record's current state.
 *
 * The previous current state (if any) is moved to the front of the archived
 * list, a reference is taken on `promoted_state`, and any archived states
 * beyond ARCHIVED_STATES_MAX_LENGTH are released.
 *
 * @param record         record to modify; must not be NULL
 * @param promoted_state state to make current; must not be NULL
 * @return 0 on success, SG_ERR_NOMEM if the archive node cannot be allocated
 *         (the record is left unchanged in that case)
 */
int session_record_promote_state(session_record *record, session_state *promoted_state)
{
    session_record_state_node *iter = 0;
    session_record_state_node *next_iter = 0;
    int seen = 0;

    assert(record);
    assert(promoted_state);

    if(record->state) {
        session_record_state_node *archived = malloc(sizeof(session_record_state_node));
        if(!archived) {
            return SG_ERR_NOMEM;
        }
        /* The list node takes over the record's reference on the old state. */
        archived->state = record->state;
        record->state = 0;
        DL_PREPEND(record->previous_states_head, archived);
    }

    SIGNAL_REF(promoted_state);
    record->state = promoted_state;

    /* Keep the first ARCHIVED_STATES_MAX_LENGTH archived states; drop the rest. */
    DL_FOREACH_SAFE(record->previous_states_head, iter, next_iter) {
        seen++;
        if(seen <= ARCHIVED_STATES_MAX_LENGTH) {
            continue;
        }
        DL_DELETE(record->previous_states_head, iter);
        if(iter->state) {
            SIGNAL_UNREF(iter->state);
        }
        free(iter);
    }

    return 0;
}
422 
/* Release every archived-state node (and the state reference each holds),
 * leaving the record's archived list empty. */
static void session_record_free_previous_states(session_record *record)
{
    while(record->previous_states_head) {
        session_record_state_node *victim = record->previous_states_head;

        DL_DELETE(record->previous_states_head, victim);
        if(victim->state) {
            SIGNAL_UNREF(victim->state);
        }
        free(victim);
    }
    record->previous_states_head = 0;
}
436 
session_record_get_user_record(const session_record * record)437 signal_buffer *session_record_get_user_record(const session_record *record)
438 {
439     assert(record);
440     return record->user_record;
441 }
442 
/**
 * Attach an opaque user record buffer to the record.
 *
 * The record takes ownership of `user_record`; any previously attached buffer
 * is freed. Passing the buffer that is already attached is harmless.
 *
 * @param record      record to modify; must not be NULL
 * @param user_record buffer to attach, or NULL to detach (and free) the
 *                    current one
 */
void session_record_set_user_record(session_record *record, signal_buffer *user_record)
{
    assert(record);
    // Re-attaching the buffer the record already owns must not free it,
    // otherwise the record would keep a dangling pointer (and free it again
    // in session_record_destroy()).
    if(record->user_record && record->user_record != user_record) {
        signal_buffer_free(record->user_record);
    }
    record->user_record = user_record;
}
451 
/**
 * Destructor registered via SIGNAL_INIT() in session_record_create().
 *
 * Releases the attached user record buffer, every archived state, the current
 * state and finally the record allocation itself.
 */
void session_record_destroy(signal_type_base *type)
{
    session_record *self = (session_record *)type;

    if(self->user_record) {
        signal_buffer_free(self->user_record);
    }

    session_record_free_previous_states(self);

    if(self->state) {
        SIGNAL_UNREF(self->state);
    }

    free(self);
}
467