1 #include "session_record.h"
2
3 #include <stdlib.h>
4 #include <string.h>
5 #include <assert.h>
6
7 #include "session_state.h"
8 #include "utlist.h"
9 #include "LocalStorageProtocol.pb-c.h"
10 #include "signal_protocol_internal.h"
11
/* Maximum number of archived (previous) session states a record keeps;
 * session_record_promote_state() drops list entries beyond this count. */
#define ARCHIVED_STATES_MAX_LENGTH 40
13
/* Doubly linked list node (manipulated with the utlist DL_* macros) holding
 * one archived session state. The node holds a reference to `state`, which is
 * released when the node is removed or the record is destroyed. */
struct session_record_state_node
{
session_state *state;
struct session_record_state_node *prev, *next;
};
19
/* A session record: the current session state plus a list of archived states. */
struct session_record
{
/* Reference-counting header; must remain the first member because
 * session_record_destroy() casts a signal_type_base * to session_record *. */
signal_type_base base;
/* Current session state; the record holds a reference to it. */
session_state *state;
/* Archived states; new entries are prepended by session_record_promote_state(). */
session_record_state_node *previous_states_head;
/* 1 when session_record_create() had to create its own state, 0 otherwise. */
int is_fresh;
/* Optional caller-supplied buffer owned by the record; not serialized. */
signal_buffer *user_record;
/* Context used when this record needs to create new session states. */
signal_context *global_context;
};
29
/* Releases every archived state node of a record (defined further below). */
static void session_record_free_previous_states(session_record *record);
31
/*
 * Allocates a new, reference-counted session record.
 *
 * When `state` is NULL, a new session state is created for the record and the
 * record is marked fresh. Otherwise the record takes its own reference to
 * `state` and is not marked fresh.
 *
 * @param record          receives the new record on success
 * @param state           optional initial current state (may be NULL)
 * @param global_context  stored in the record; also used to create a state
 * @return 0 on success, SG_ERR_NOMEM if the record cannot be allocated, or the
 *         negative error returned by session_state_create()
 */
int session_record_create(session_record **record, session_state *state, signal_context *global_context)
{
    session_record *rec = calloc(1, sizeof(session_record));
    if(!rec) {
        return SG_ERR_NOMEM;
    }
    SIGNAL_INIT(rec, session_record_destroy);

    if(state) {
        /* Share the caller's state; the record keeps its own reference. */
        SIGNAL_REF(state);
        rec->state = state;
        rec->is_fresh = 0;
    }
    else {
        int ret = session_state_create(&rec->state, global_context);
        if(ret < 0) {
            SIGNAL_UNREF(rec);
            return ret;
        }
        rec->is_fresh = 1;
    }
    rec->global_context = global_context;

    *record = rec;
    return 0;
}
59
/*
 * Serializes a session record into a newly allocated protobuf buffer.
 *
 * The current state (if any) and every archived previous state are encoded
 * into a Textsecure__RecordStructure. The user record is not included.
 *
 * @param buffer  receives the newly allocated buffer on success; the caller
 *                owns it and releases it with signal_buffer_free()
 * @param record  record to serialize
 * @return 0 on success; SG_ERR_INVAL if record is NULL; SG_ERR_NOMEM on
 *         allocation failure; SG_ERR_INVALID_PROTO_BUF if packing wrote an
 *         unexpected number of bytes; otherwise the negative error returned
 *         by session_state_serialize_prepare()
 */
int session_record_serialize(signal_buffer **buffer, const session_record *record)
{
    int result = 0;
    size_t result_size = 0;
    size_t i = 0;
    Textsecure__RecordStructure record_structure = TEXTSECURE__RECORD_STRUCTURE__INIT;
    session_record_state_node *cur_node = 0;
    signal_buffer *result_buf = 0;
    size_t len = 0;
    uint8_t *data = 0;

    if(!record) {
        result = SG_ERR_INVAL;
        goto complete;
    }

    if(record->state) {
        record_structure.currentsession = malloc(sizeof(Textsecure__SessionStructure));
        if(!record_structure.currentsession) {
            result = SG_ERR_NOMEM;
            goto complete;
        }
        textsecure__session_structure__init(record_structure.currentsession);
        result = session_state_serialize_prepare(record->state, record_structure.currentsession);
        if(result < 0) {
            goto complete;
        }
    }

    if(record->previous_states_head) {
        size_t count;
        DL_COUNT(record->previous_states_head, cur_node, count);

        /* calloc() returns NULL if count * size would overflow. */
        record_structure.previoussessions = calloc(count, sizeof(Textsecure__SessionStructure *));
        if(!record_structure.previoussessions) {
            result = SG_ERR_NOMEM;
            goto complete;
        }

        i = 0;
        DL_FOREACH(record->previous_states_head, cur_node) {
            Textsecure__SessionStructure *session_structure = malloc(sizeof(Textsecure__SessionStructure));
            if(!session_structure) {
                result = SG_ERR_NOMEM;
                break;
            }
            textsecure__session_structure__init(session_structure);

            /* Count the slot before preparing it, so the cleanup below also
             * releases a structure whose preparation fails part-way (the same
             * way a failed currentsession is released). */
            record_structure.previoussessions[i++] = session_structure;

            result = session_state_serialize_prepare(cur_node->state, session_structure);
            if(result < 0) {
                break;
            }
        }
        record_structure.n_previoussessions = i;
        if(result < 0) {
            goto complete;
        }
    }

    len = textsecure__record_structure__get_packed_size(&record_structure);

    result_buf = signal_buffer_alloc(len);
    if(!result_buf) {
        result = SG_ERR_NOMEM;
        goto complete;
    }

    data = signal_buffer_data(result_buf);
    result_size = textsecure__record_structure__pack(&record_structure, data);
    if(result_size != len) {
        signal_buffer_free(result_buf);
        result = SG_ERR_INVALID_PROTO_BUF;
        result_buf = 0;
        goto complete;
    }

complete:
    if(record_structure.currentsession) {
        session_state_serialize_prepare_free(record_structure.currentsession);
    }
    if(record_structure.previoussessions) {
        for(i = 0; i < record_structure.n_previoussessions; i++) {
            if(record_structure.previoussessions[i]) {
                session_state_serialize_prepare_free(record_structure.previoussessions[i]);
            }
        }
        free(record_structure.previoussessions);
    }

    if(result >= 0) {
        *buffer = result_buf;
    }
    return result;
}
159
/*
 * Reconstructs a session record from a protobuf-encoded buffer.
 *
 * If the encoded record has no current session, the new record gets a newly
 * created state (via session_record_create()). The resulting record always
 * has is_fresh == 0. Archived states are appended in encoded order.
 *
 * @param record          receives the new record on success
 * @param data            protobuf bytes
 * @param len             number of bytes in data
 * @param global_context  context passed to the state constructors
 * @return 0 on success; SG_ERR_INVALID_PROTO_BUF if unpacking fails;
 *         SG_ERR_NOMEM on allocation failure; otherwise a negative error
 *         from session_state_deserialize_protobuf() / session_record_create()
 */
int session_record_deserialize(session_record **record, const uint8_t *data, size_t len, signal_context *global_context)
{
    int result = 0;
    size_t i;
    session_record *result_record = 0;
    session_state *current_state = 0;
    session_record_state_node *previous_states_head = 0;
    Textsecure__RecordStructure *record_structure = 0;

    record_structure = textsecure__record_structure__unpack(0, len, data);
    if(!record_structure) {
        result = SG_ERR_INVALID_PROTO_BUF;
        goto complete;
    }

    if(record_structure->currentsession) {
        result = session_state_deserialize_protobuf(&current_state, record_structure->currentsession, global_context);
        if(result < 0) {
            goto complete;
        }
    }

    result = session_record_create(&result_record, current_state, global_context);
    if(result < 0) {
        goto complete;
    }
    /* The record now holds its own reference to current_state. */
    SIGNAL_UNREF(current_state);
    current_state = 0;
    result_record->is_fresh = 0;

    for(i = 0; i < record_structure->n_previoussessions; i++) {
        Textsecure__SessionStructure *session_structure =
            record_structure->previoussessions[i];

        session_record_state_node *node = malloc(sizeof(session_record_state_node));
        if(!node) {
            result = SG_ERR_NOMEM;
            goto complete;
        }

        result = session_state_deserialize_protobuf(&node->state, session_structure, global_context);
        if(result < 0) {
            free(node);
            goto complete;
        }

        DL_APPEND(previous_states_head, node);
    }
    /* Hand the list over to the record; the cleanup below must not free it. */
    result_record->previous_states_head = previous_states_head;
    previous_states_head = 0;

complete:
    if(record_structure) {
        textsecure__record_structure__free_unpacked(record_structure, 0);
    }
    if(current_state) {
        SIGNAL_UNREF(current_state);
    }
    if(previous_states_head) {
        /* Only reached on failure: release both the nodes and the states
         * they reference, otherwise every deserialized state leaks. */
        session_record_state_node *cur_node;
        session_record_state_node *tmp_node;
        DL_FOREACH_SAFE(previous_states_head, cur_node, tmp_node) {
            DL_DELETE(previous_states_head, cur_node);
            if(cur_node->state) {
                SIGNAL_UNREF(cur_node->state);
            }
            free(cur_node);
        }
    }
    if(result_record) {
        if(result < 0) {
            SIGNAL_UNREF(result_record);
        }
        else {
            *record = result_record;
        }
    }

    return result;
}
239
/*
 * Creates a deep copy of a session record.
 *
 * The states are copied through a serialize / deserialize round trip. The
 * user record is not part of the serialized form, so it is duplicated
 * separately with signal_buffer_copy().
 *
 * @param record          receives the copy on success
 * @param other_record    record to copy (must not be NULL)
 * @param global_context  context for the copy (must not be NULL)
 * @return 0 on success, SG_ERR_NOMEM if the user record cannot be copied, or
 *         a negative error from session_record_serialize()/_deserialize()
 */
int session_record_copy(session_record **record, session_record *other_record, signal_context *global_context)
{
    int result;
    session_record *copy = 0;
    signal_buffer *serialized = 0;

    assert(other_record);
    assert(global_context);

    result = session_record_serialize(&serialized, other_record);
    if(result >= 0) {
        uint8_t *bytes = signal_buffer_data(serialized);
        size_t byte_count = signal_buffer_len(serialized);
        result = session_record_deserialize(&copy, bytes, byte_count, global_context);
    }

    if(result >= 0 && other_record->user_record) {
        copy->user_record = signal_buffer_copy(other_record->user_record);
        if(!copy->user_record) {
            result = SG_ERR_NOMEM;
        }
    }

    if(serialized) {
        signal_buffer_free(serialized);
    }

    if(result < 0) {
        SIGNAL_UNREF(copy);
        return result;
    }
    *record = copy;
    return result;
}
283
/*
 * Reports whether the record contains a state with the given session version
 * and Alice base key.
 *
 * The current state is checked first, then each archived state in list order.
 * A key matches when ec_public_key_compare() returns 0.
 *
 * @param record          record to search; it must have a current state
 * @param version         session version to match
 * @param alice_base_key  base key to match
 * @return 1 if a matching state exists, 0 otherwise
 */
int session_record_has_session_state(session_record *record, uint32_t version, const ec_public_key *alice_base_key)
{
    session_state *candidate;
    session_record_state_node *upcoming;

    assert(record);
    assert(record->state);

    candidate = record->state;
    upcoming = record->previous_states_head;
    for(;;) {
        int same_version = session_state_get_session_version(candidate) == version;
        if(same_version &&
           ec_public_key_compare(session_state_get_alice_base_key(candidate), alice_base_key) == 0) {
            return 1;
        }
        if(!upcoming) {
            return 0;
        }
        candidate = upcoming->state;
        upcoming = upcoming->next;
    }
}
309
session_record_get_state(session_record * record)310 session_state *session_record_get_state(session_record *record)
311 {
312 return record->state;
313 }
314
/*
 * Replaces the record's current session state.
 *
 * The record takes a reference to `state` and releases its reference to the
 * previous current state, if any.
 *
 * @param record  record to modify (must not be NULL)
 * @param state   new current state (must not be NULL)
 */
void session_record_set_state(session_record *record, session_state *state)
{
    session_state *old_state;

    assert(record);
    assert(state);

    /* Take the new reference before dropping the old one. If `state` is
     * already the current state and the record holds the last reference,
     * unreferencing first would destroy it before it is re-referenced. */
    SIGNAL_REF(state);
    old_state = record->state;
    record->state = state;
    if(old_state) {
        SIGNAL_UNREF(old_state);
    }
}
325
session_record_get_previous_states_head(const session_record * record)326 session_record_state_node *session_record_get_previous_states_head(const session_record *record)
327 {
328 assert(record);
329 return record->previous_states_head;
330 }
331
session_record_get_previous_states_element(const session_record_state_node * node)332 session_state *session_record_get_previous_states_element(const session_record_state_node *node)
333 {
334 assert(node);
335 return node->state;
336 }
337
session_record_get_previous_states_next(const session_record_state_node * node)338 session_record_state_node *session_record_get_previous_states_next(const session_record_state_node *node)
339 {
340 assert(node);
341 return node->next;
342 }
343
/*
 * Unlinks `node` from the record's archived-state list, releases its state
 * reference and frees the node.
 *
 * @param record  record owning the list (must not be NULL)
 * @param node    node to remove (must not be NULL); invalid after the call
 * @return the node that followed `node`, or NULL if it was the last one,
 *         so callers can keep iterating
 */
session_record_state_node *session_record_get_previous_states_remove(session_record *record, session_record_state_node *node)
{
    session_record_state_node *following;

    assert(record);
    assert(node);

    /* Read the successor now; `node` is freed below. */
    following = node->next;

    DL_DELETE(record->previous_states_head, node);
    SIGNAL_UNREF(node->state);
    free(node);

    return following;
}
357
/*
 * Returns the record's fresh flag: 1 when session_record_create() had to
 * create the record's own state, 0 otherwise (including deserialized
 * records). The record must not be NULL.
 */
int session_record_is_fresh(session_record *record)
{
    int fresh;

    assert(record);
    fresh = record->is_fresh;
    return fresh;
}
363
/*
 * Moves the current state into the archive and replaces it with a newly
 * created state.
 *
 * @param record  record to modify (must not be NULL)
 * @return 0 on success, or the negative error from session_state_create()
 *         or session_record_promote_state()
 */
int session_record_archive_current_state(session_record *record)
{
    int result;
    session_state *replacement = 0;

    assert(record);

    result = session_state_create(&replacement, record->global_context);
    if(result >= 0) {
        /* Promotion archives the old current state and takes its own
         * reference to the replacement. */
        result = session_record_promote_state(record, replacement);
    }

    /* Drop the local reference from session_state_create(). */
    SIGNAL_UNREF(replacement);
    return result;
}
382
/*
 * Makes `promoted_state` the record's current state.
 *
 * The previous current state, if any, is prepended to the archived-state
 * list. The list is then trimmed so that at most ARCHIVED_STATES_MAX_LENGTH
 * entries remain; entries beyond that position (at the tail) are released.
 *
 * @param record          record to modify (must not be NULL)
 * @param promoted_state  new current state (must not be NULL); the record
 *                        takes its own reference to it
 * @return 0 on success, SG_ERR_NOMEM if the archive node cannot be allocated
 *         (in which case the record is left unchanged)
 */
int session_record_promote_state(session_record *record, session_state *promoted_state)
{
    session_record_state_node *node = 0;
    session_record_state_node *next_node = 0;
    int position = 0;

    assert(record);
    assert(promoted_state);

    if(record->state) {
        session_record_state_node *archived = malloc(sizeof(session_record_state_node));
        if(!archived) {
            return SG_ERR_NOMEM;
        }
        /* The archive node inherits the record's reference to the old state. */
        archived->state = record->state;
        DL_PREPEND(record->previous_states_head, archived);
        record->state = 0;
    }

    SIGNAL_REF(promoted_state);
    record->state = promoted_state;

    DL_FOREACH_SAFE(record->previous_states_head, node, next_node) {
        position++;
        if(position <= ARCHIVED_STATES_MAX_LENGTH) {
            continue;
        }
        DL_DELETE(record->previous_states_head, node);
        if(node->state) {
            SIGNAL_UNREF(node->state);
        }
        free(node);
    }

    return 0;
}
422
/*
 * Releases every archived-state node of `record`, dropping each node's state
 * reference, and leaves the list empty.
 */
static void session_record_free_previous_states(session_record *record)
{
    /* Repeatedly detach the head until the list is exhausted. */
    while(record->previous_states_head) {
        session_record_state_node *victim = record->previous_states_head;
        DL_DELETE(record->previous_states_head, victim);
        if(victim->state) {
            SIGNAL_UNREF(victim->state);
        }
        free(victim);
    }
    record->previous_states_head = 0;
}
436
session_record_get_user_record(const session_record * record)437 signal_buffer *session_record_get_user_record(const session_record *record)
438 {
439 assert(record);
440 return record->user_record;
441 }
442
/*
 * Attaches a user buffer to the record, transferring ownership to it.
 *
 * Any previously attached buffer is freed. Passing the buffer that is
 * already attached is a no-op.
 *
 * @param record       record to modify (must not be NULL)
 * @param user_record  buffer to attach (may be NULL to clear it)
 */
void session_record_set_user_record(session_record *record, signal_buffer *user_record)
{
    assert(record);

    /* Re-attaching the same buffer must not free it and leave a dangling
     * pointer behind. */
    if(record->user_record == user_record) {
        return;
    }
    if(record->user_record) {
        signal_buffer_free(record->user_record);
    }
    record->user_record = user_record;
}
451
/*
 * Destructor registered via SIGNAL_INIT.
 *
 * Releases the current state, all archived states and the user buffer, then
 * frees the record itself.
 */
void session_record_destroy(signal_type_base *type)
{
    session_record *self = (session_record *)type;
    signal_buffer *owned_buffer;

    if(self->state) {
        SIGNAL_UNREF(self->state);
    }
    session_record_free_previous_states(self);

    owned_buffer = self->user_record;
    if(owned_buffer) {
        signal_buffer_free(owned_buffer);
    }

    free(self);
}
467