1 #include "trie.h"
2 #include <math.h>
3 
/*
* Default alphabet: all 256 byte values (suitable for UTF-8 strings), listed
* in order of frequency of usage in Wikipedia titles.
* The position of a byte in this table is the array index it is mapped to;
* trie_new_empty() builds the inverse lookup (alpha_map[alphabet[i]] = i).
* In practice the order of the bytes shouldn't matter for larger key sets,
* but it may save space for a small number of keys.
*/
uint8_t DEFAULT_ALPHABET[] = {
32, 97, 101, 105, 111, 110, 114, 0, 116, 108, 115, 117, 104, 99, 100, 109,
103, 121, 83, 112, 67, 98, 107, 77, 65, 102, 118, 66, 80, 84, 41, 40,
119, 82, 72, 68, 76, 71, 70, 87, 49, 44, 78, 75, 69, 74, 73, 48,
195, 122, 45, 50, 57, 79, 86, 46, 120, 85, 106, 39, 56, 51, 52, 89,
128, 226, 147, 55, 53, 54, 197, 113, 196, 90, 169, 161, 81, 179, 58, 88,
173, 188, 141, 182, 153, 177, 38, 130, 135, 164, 159, 47, 168, 33, 186, 167,
129, 200, 131, 162, 155, 184, 163, 171, 160, 137, 132, 190, 133, 34, 225, 187,
165, 189, 176, 63, 201, 140, 154, 180, 151, 170, 145, 175, 43, 152, 150, 166,
158, 194, 198, 178, 144, 181, 148, 134, 136, 42, 185, 174, 156, 143, 172, 191,
142, 96, 59, 202, 139, 183, 64, 206, 157, 61, 146, 36, 37, 199, 149, 126,
229, 230, 204, 233, 231, 207, 138, 208, 232, 92, 227, 228, 209, 94, 224, 239,
217, 205, 221, 218, 211, 4, 8, 12, 16, 20, 24, 28, 60, 203, 215, 219,
223, 235, 243, 247, 251, 124, 254, 3, 7, 11, 15, 19, 23, 27, 31, 35,
192, 212, 216, 91, 220, 95, 236, 240, 244, 248, 123, 252, 127, 2, 6, 10,
14, 18, 22, 26, 30, 62, 193, 213, 237, 241, 245, 249, 253, 1, 5, 9,
13, 17, 21, 25, 29, 210, 214, 93, 222, 234, 238, 242, 246, 250, 125, 255
};
28 
29 
30 /*
31 Constructors
32 */
33 
trie_new_empty(uint8_t * alphabet,uint32_t alphabet_size)34 static trie_t *trie_new_empty(uint8_t *alphabet, uint32_t alphabet_size) {
35     trie_t *self = calloc(1, sizeof(trie_t));
36     if (!self)
37         goto exit_no_malloc;
38 
39     self->nodes = trie_node_array_new_size(DEFAULT_NODE_ARRAY_SIZE);
40     if (!self->nodes)
41         goto exit_trie_created;
42 
43     self->null_node = NULL_NODE;
44 
45     self->tail = uchar_array_new_size(1);
46     if (!self->tail)
47         goto exit_node_array_created;
48 
49     self->alphabet = malloc(alphabet_size);
50     if (!self->alphabet)
51         goto exit_tail_created;
52     memcpy(self->alphabet, alphabet, alphabet_size);
53 
54     self->alphabet_size = alphabet_size;
55 
56     self->num_keys = 0;
57 
58     for (int i = 0; i < self->alphabet_size; i++) {
59         self->alpha_map[alphabet[i]] = i;
60         log_debug("setting alpha_map[%d] = %d\n", alphabet[i], i);
61     }
62 
63     self->data = trie_data_array_new_size(1);
64     if (!self->data)
65         goto exit_alphabet_created;
66 
67     return self;
68 
69 exit_alphabet_created:
70     free(self->alphabet);
71 exit_tail_created:
72     uchar_array_destroy(self->tail);
73 exit_node_array_created:
74     trie_node_array_destroy(self->nodes);
75 exit_trie_created:
76     free(self);
77 exit_no_malloc:
78     return NULL;
79 }
80 
trie_new_alphabet(uint8_t * alphabet,uint32_t alphabet_size)81 trie_t *trie_new_alphabet(uint8_t *alphabet, uint32_t alphabet_size) {
82     trie_t *self = trie_new_empty(alphabet, alphabet_size);
83     if (!self)
84         return NULL;
85 
86     trie_node_array_push(self->nodes, (trie_node_t){0, 0});
87     // Circular reference  point for first and last free nodes in the linked list
88     trie_node_array_push(self->nodes, (trie_node_t){-1, -1});
89     // Root node
90     trie_node_array_push(self->nodes, (trie_node_t){TRIE_POOL_BEGIN, 0});
91 
92     uchar_array_push(self->tail, '\0');
93     // Since data indexes are negative integers, index 0 is not valid, so pad it
94     trie_data_array_push(self->data, (trie_data_node_t){0, 0});
95 
96     return self;
97 }
98 
trie_new(void)99 trie_t *trie_new(void) {
100     return trie_new_alphabet(DEFAULT_ALPHABET, sizeof(DEFAULT_ALPHABET));
101 }
102 
/* A node is free when its check field is negative (free cells store negated list links). */
inline bool trie_node_is_free(trie_node_t node) {
    bool in_use = node.check >= 0;
    return !in_use;
}
106 
trie_get_node(trie_t * self,uint32_t index)107 inline trie_node_t trie_get_node(trie_t *self, uint32_t index) {
108     if ((index >= self->nodes->n) || index < ROOT_NODE_ID) return self->null_node;
109     return self->nodes->a[index];
110 }
111 
/* Overwrites the base field of the node at index. No bounds check is performed. */
inline void trie_set_base(trie_t *self, uint32_t index, int32_t base) {
    log_debug("Setting base at %d to %d\n", index, base);
    trie_node_t *slot = &self->nodes->a[index];
    slot->base = base;
}
116 
/* Overwrites the check field of the node at index. No bounds check is performed. */
inline void trie_set_check(trie_t *self, uint32_t index, int32_t check) {
    log_debug("Setting check at %d to %d\n", index, check);
    trie_node_t *slot = &self->nodes->a[index];
    slot->check = check;
}
121 
122 
trie_get_root(trie_t * self)123 inline trie_node_t trie_get_root(trie_t *self) {
124     return self->nodes->a[ROOT_NODE_ID];
125 }
126 
trie_get_free_list(trie_t * self)127 inline trie_node_t trie_get_free_list(trie_t *self) {
128     return self->nodes->a[FREE_LIST_ID];
129 }
130 
131 
132 /*
133 * Private implementation
134 */
135 
136 
137 
/*
* Grows the node array so that to_index is a valid index and appends every
* newly created cell to the end of the free list.
*
* Free cells form a doubly linked list through negated indices: base holds
* -(previous free index) and check holds -(next free index).
*
* Returns false when to_index is 0 or not below TRIE_MAX_INDEX, true otherwise
* (including when the array is already large enough).
*/
static bool trie_extend(trie_t *self, uint32_t to_index) {
    if (to_index <= 0 || TRIE_MAX_INDEX <= to_index) {
        return false;
    }

    // Nothing to do, the array already covers to_index
    if (to_index < self->nodes->n) {
        return true;
    }

    uint32_t first_new = (uint32_t)self->nodes->n;

    // Append cells first_new..to_index, each pre-linked to its neighbours
    uint32_t cell = first_new;
    while (cell < to_index + 1) {
        trie_node_array_push(self->nodes, (trie_node_t){-(cell-1), -(cell+1)});
        cell++;
    }

    // Splice the new run between the old last free cell and the sentinel
    trie_node_t sentinel = trie_get_free_list(self);
    uint32_t old_last = -sentinel.base;

    trie_set_check(self, old_last, -first_new);
    trie_set_base(self, first_new, -old_last);
    trie_set_check(self, to_index, -FREE_LIST_ID);
    trie_set_base(self, FREE_LIST_ID, -to_index);

    return true;
}
162 
/*
* Ensures the node array is long enough to hold every possible child slot
* starting at next_id (i.e. up to next_id + alphabet_size). The result of
* trie_extend() is not reported to the caller.
*/
void trie_make_room_for(trie_t *self, uint32_t next_id) {
    if (next_id + self->alphabet_size < self->nodes->n) {
        return;
    }

    trie_extend(self, next_id + self->alphabet_size);
    log_debug("extended to %zu\n", self->nodes->n);
}
169 
/* Replaces the whole node (base and check) stored at index. No bounds check is performed. */
static inline void trie_set_node(trie_t *self, uint32_t index, trie_node_t node) {
    log_debug("setting node, index=%d, node=(%d,%d)\n", index, node.base, node.check);
    trie_node_t *slot = &self->nodes->a[index];
    *slot = node;
}
174 
/*
* Unlinks the free cell at index from the free list by pointing its
* predecessor and successor at each other. The cell's own fields are left
* for the caller to overwrite.
*/
static void trie_init_node(trie_t *self, uint32_t index) {
    trie_node_t cell = trie_get_node(self, index);

    // Free cells store their neighbours as negated indices
    int32_t before = -cell.base;
    int32_t after = -cell.check;

    trie_set_check(self, before, -after);
    trie_set_base(self, after, -before);
}
186 
/*
* Returns the cell at index to the free list, keeping the list sorted by
* index: the cell is inserted just before the first free cell whose index is
* not smaller than index, or at the end of the list if there is none.
*
* Free cells are linked through negated indices (base = -previous,
* check = -next), with the sentinel at FREE_LIST_ID closing the ring.
*/
static void trie_free_node(trie_t *self, uint32_t index) {
    trie_node_t free_list_node = trie_get_free_list(self);

    // Walk forward from the first free cell until we pass index or wrap around
    int32_t i = -free_list_node.check;
    while (i != FREE_LIST_ID && (uint32_t)i < index) {
        i = -trie_get_node(self, i).check;
    }

    // trie_get_node() returns null_node for indices below ROOT_NODE_ID, so
    // when the walk wrapped around to the sentinel we must use the sentinel
    // itself; otherwise the predecessor (the current last free cell) is lost
    // and the list is corrupted.
    trie_node_t successor = (i == FREE_LIST_ID) ? free_list_node : trie_get_node(self, i);
    int32_t prev = -successor.base;

    trie_set_node(self, index, (trie_node_t){-prev, -i});

    trie_set_check(self, prev, -index);
    trie_set_base(self, i, -index);
}
206 
207 
/*
* Returns true if at least one alphabet character has a transition out of
* node_id, i.e. some cell at base + char_index has check == node_id.
* Returns false for out-of-range ids and for nodes with a negative base
* (those carry no child transitions).
*/
static bool trie_node_has_children(trie_t *self, uint32_t node_id) {
    // Valid indices are 0..n-1, so node_id == n is already out of range
    if (node_id >= self->nodes->n)
        return false;
    trie_node_t node = trie_get_node(self, node_id);
    if (node.base < 0)
        return false;
    for (uint32_t i = 0; i < self->alphabet_size; i++) {
        unsigned char c = self->alphabet[i];
        uint32_t index = trie_get_transition_index(self, node, c);
        if (index < self->nodes->n && (uint32_t)trie_get_node(self, index).check == node_id)
            return true;
    }
    return false;
}
223 
/*
* Frees childless nodes starting at s and climbing through each node's
* check (its parent), stopping at p or at the first node that still has
* children. p itself is never freed.
*/
static void trie_prune_up_to(trie_t *self, uint32_t p, uint32_t s) {
    log_debug("Pruning from %d to %d\n", s, p);
    log_debug("%d has_children=%d\n", s, trie_node_has_children(self, s));

    for (;;) {
        if (s == p) {
            break;
        }
        if (trie_node_has_children(self, s)) {
            break;
        }

        // Remember the parent before the cell is recycled
        uint32_t parent = trie_get_node(self, s).check;
        trie_free_node(self, s);
        s = parent;
    }
}
233 
/* Prunes childless nodes from s upwards, stopping at the root at the latest. */
static void trie_prune(trie_t *self, uint32_t s) {
    uint32_t stop_at = ROOT_NODE_ID;
    trie_prune_up_to(self, stop_at, s);
}
237 
/*
* Collects, in alphabet order, every character that has a transition out of
* node_id.
*
* transitions:     output buffer; must have room for alphabet_size bytes
* num_transitions: receives the number of characters written
*/
static void trie_get_transition_chars(trie_t *self, uint32_t node_id, unsigned char *transitions, uint32_t *num_transitions) {
    trie_node_t parent = trie_get_node(self, node_id);
    uint32_t found = 0;

    for (int pos = 0; pos < self->alphabet_size; pos++) {
        unsigned char c = self->alphabet[pos];
        uint32_t child_id = trie_get_transition_index(self, parent, c);

        if (child_id >= self->nodes->n) {
            continue;
        }
        // A cell is a child of node_id only if its check points back at it
        if (trie_get_node(self, child_id).check != node_id) {
            continue;
        }

        log_debug("adding transition char %c to index %d\n", c, found);
        transitions[found] = c;
        found++;
    }

    *num_transitions = found;
}
253 
254 
/*
* Returns true if node_id could serve as a base for all given transition
* characters: for each one, node_id + char_index must not exceed
* TRIE_MAX_INDEX and the cell there must be free.
*/
static bool trie_can_fit_transitions(trie_t *self, uint32_t node_id, unsigned char *transitions, uint32_t num_transitions) {
    for (uint32_t t = 0; t < num_transitions; t++) {
        uint32_t offset = trie_get_char_index(self, transitions[t]);
        trie_node_t candidate = trie_get_node(self, node_id + offset);

        if (node_id > TRIE_MAX_INDEX - offset) {
            return false;
        }
        if (!trie_node_is_free(candidate)) {
            return false;
        }
    }

    return true;
}
272 
/*
* Searches the free list for a base value such that, for every character in
* transitions, the cell at base + char_index is free. The node array is
* extended when the free list runs out of candidates.
*
* transitions must hold at least one character: transitions[0] is read
* unconditionally and anchors the search.
*
* Returns the base found, or TRIE_INDEX_ERROR if trie_extend() fails.
*/
static uint32_t trie_find_new_base(trie_t *self, unsigned char *transitions, uint32_t num_transitions) {
    uint32_t first_char_index = trie_get_char_index(self, transitions[0]);

    trie_node_t node = trie_get_free_list(self);
    // check of the sentinel is the negated index of the first free cell
    uint32_t index = -node.check;

    // Skip free cells that are too low: the resulting base
    // (index - first_char_index) must not fall below TRIE_POOL_BEGIN
    while (index != FREE_LIST_ID && index < first_char_index + TRIE_POOL_BEGIN) {
        node = trie_get_node(self, index);
        index = -node.check;
    }


    // Wrapped around to the sentinel without a usable cell: grow the array
    // one index at a time until the cell at index is free
    if (index == FREE_LIST_ID) {
        for (index = first_char_index + TRIE_POOL_BEGIN; ; index++) {
            if (!trie_extend(self, index)) {
                log_error("Trie index error extending to %d\n", index);
                return TRIE_INDEX_ERROR;
            }
            node = trie_get_node(self, index);
            if (node.check < 0)
                break;
        }
    }

    // search for next free cell that fits the transitions
    while (!trie_can_fit_transitions(self, index - first_char_index, transitions, num_transitions)) {
        // NOTE: this declaration shadows the outer `node`
        trie_node_t node = trie_get_node(self, index);
        // index is the last free cell: extend the array, then re-read the
        // cell because trie_extend() re-points its check at the new cells
        if (-node.check == FREE_LIST_ID) {
            if (!trie_extend(self, (uint32_t) self->nodes->n + self->alphabet_size)) {
                // NOTE(review): the message prints index, not the index the extension was attempted to
                log_error("Trie index error extending to %d\n", index);
                return TRIE_INDEX_ERROR;
            }
            node = trie_get_node(self, index);
        }

        // Advance to the next free cell
        index = -node.check;

    }

    return index - first_char_index;

}
315 
/*
* Returns the node-array capacity needed to hold TRIE_POOL_BEGIN + index
* cells, obtained by doubling the current capacity (self->nodes->m) until it
* is large enough.
*/
static size_t trie_required_size(trie_t *self, uint32_t index) {
    // Sum in size_t so a large index cannot wrap the 32-bit addition
    size_t needed = (size_t)TRIE_POOL_BEGIN + (size_t)index;

    size_t array_size = (size_t)self->nodes->m;
    // Doubling zero never terminates; start from one slot instead
    if (array_size == 0) {
        array_size = 1;
    }

    while (array_size < needed) {
        // Doubling would overflow size_t: fall back to the exact requirement
        if (array_size > SIZE_MAX / 2) {
            return needed;
        }
        array_size *= 2;
    }
    return array_size;
}
324 
/*
* Moves every child of the node at current_index so that the children are
* addressed from new_base instead of the node's current base, then stores
* new_base as the node's base.
*
* For each existing transition: the destination cell is taken off the free
* list and receives the child's base, the child's own children get their
* check re-pointed at the destination, and the old cell is freed.
*
* The cells at new_base + char_index are expected to be free for every
* existing transition character (see trie_find_new_base).
*/
static void trie_relocate_base(trie_t *self, uint32_t current_index, int32_t new_base) {
    log_debug("Relocating base at %d\n", current_index);
    uint32_t i;

    // Make sure every possible destination slot exists in the node array
    trie_make_room_for(self, new_base);

    trie_node_t old_node = trie_get_node(self, current_index);

    uint32_t num_transitions = 0;
    // Variable-length array: a node has at most alphabet_size transitions
    unsigned char transitions[self->alphabet_size];
    trie_get_transition_chars(self, current_index, transitions, &num_transitions);

    for (i = 0; i < num_transitions; i++) {
        unsigned char c = transitions[i];

        uint32_t char_index = trie_get_char_index(self, c);

        uint32_t old_index = old_node.base + char_index;
        uint32_t new_index = new_base + char_index;

        log_debug("old_index=%d\n", old_index);
        trie_node_t old_transition = trie_get_node(self, old_index);

        // Unlink the destination from the free list, then copy the child into it
        trie_init_node(self, new_index);
        trie_set_node(self, new_index, (trie_node_t){old_transition.base, current_index});

        /*
        *  All transitions out of old_index are now owned by new_index
        *  set check values appropriately
        */
        if (old_transition.base > 0) {  // do nothing in the case of a tail pointer
            for (uint32_t j = 0; j < self->alphabet_size; j++) {
                // NOTE: shadows the outer `c`
                unsigned char c = self->alphabet[j];
                uint32_t index = trie_get_transition_index(self, old_transition, c);
                if (index < self->nodes->n && trie_get_node(self, index).check == old_index) {
                    trie_set_check(self, index, new_index);
                }
            }
        }

        // Free the node at old_index
        log_debug("freeing node at %d\n", old_index);
        trie_free_node(self, old_index);

    }

    trie_set_base(self, current_index, new_base);
}
373 
374 
375 
376 /*
377 * Public methods
378 */
379 
/*
* Returns the 1-based alphabet index of byte c (its alpha_map position plus
* one), which is the offset added to a node's base to reach the child cell.
*/
inline uint32_t trie_get_char_index(trie_t *self, unsigned char c) {
    uint32_t zero_based = self->alpha_map[(uint8_t)c];
    return zero_based + 1;
}
383 
/*
* Returns the index of the cell that would hold node's transition on c:
* node.base plus the character's 1-based alphabet index. The result is not
* range-checked.
*/
inline uint32_t trie_get_transition_index(trie_t *self, trie_node_t node, unsigned char c) {
    return node.base + trie_get_char_index(self, c);
}
388 
/*
* Returns a copy of the cell at node's transition index for c, or
* self->null_node when that index lies past the end of the node array.
* The cell's check is not compared against the parent here.
*/
inline trie_node_t trie_get_transition(trie_t *self, trie_node_t node, unsigned char c) {
    uint32_t target = trie_get_transition_index(self, node, c);
    bool in_range = target < self->nodes->n;
    return in_range ? self->nodes->a[target] : self->null_node;
}
399 
/* Appends the NUL-terminated string tail, including its terminator, to the tail array. */
void trie_add_tail(trie_t *self, unsigned char *tail) {
    log_debug("Adding tail: %s\n", tail);

    const unsigned char *cursor = tail;
    while (*cursor != '\0') {
        uchar_array_push(self->tail, *cursor);
        cursor++;
    }

    uchar_array_push(self->tail, '\0');
}
408 
/*
* Writes the NUL-terminated string tail into the tail array starting at
* tail_pos, growing the array with zero bytes first if it is too short.
*
* tail may point into the tail array itself at or after tail_pos (bytes are
* copied front to back, so the read position stays ahead of the write
* position). In that situation the array is already long enough, so no
* append happens that could move the underlying buffer.
*/
void trie_set_tail(trie_t *self, unsigned char *tail, uint32_t tail_pos) {
    log_debug("Setting tail: %s at pos %d\n", tail, tail_pos);
    size_t tail_len = strlen((char *)tail);

    // One past the last byte written, counting the terminating NUL. The NUL
    // must be included here, otherwise the final store below can land on
    // a[n], outside the array's logical bounds.
    size_t required = (size_t)tail_pos + tail_len + 1;

    // Pad with 0s if we're short
    if (required > self->tail->n) {
        size_t num_appends = required - self->tail->n;
        for (size_t k = 0; k < num_appends; k++) {
            uchar_array_push(self->tail, '\0');
        }
    }

    size_t i = tail_pos;
    for (; *tail && i < self->tail->n; i++, tail++) {
        self->tail->a[i] = *tail;
    }

    // Only out of range if padding above could not grow the array
    if (i < self->tail->n) {
        self->tail->a[i] = '\0';
    }
}
427 
428 
/*
* Adds a transition on character c out of node_id and returns the index of
* the child cell.
*
* - If the node already has that transition, the existing child is returned.
* - If the node has a base and the child slot is free, the slot is claimed.
* - If the slot is taken (or would exceed TRIE_MAX_INDEX), all of the node's
*   children are relocated to a new base that also fits c.
* - If the node has no positive base yet, a base that fits c is chosen.
*
* Returns TRIE_INDEX_ERROR, without claiming any cell, when no suitable base
* can be found.
*/
uint32_t trie_add_transition(trie_t *self, uint32_t node_id, unsigned char c) {
    uint32_t next_id;
    trie_node_t node, next;
    uint32_t new_base;


    node = trie_get_node(self, node_id);
    uint32_t char_index = trie_get_char_index(self, c);

    log_debug("adding transition %c to node_id %d + char_index %d, base=%d, check=%d\n", c, node_id, char_index, node.base, node.check);


    if (node.base > 0) {
        log_debug("node.base > 0\n");
        next_id = node.base + char_index;
        log_debug("next_id=%d\n", next_id);
        trie_make_room_for(self, next_id);

        next = trie_get_node(self, next_id);

        // Transition already exists
        if (next.check == node_id) {
            return next_id;
        }

        log_debug("next.base=%d, next.check=%d\n", next.base, next.check);

        if (node.base > TRIE_MAX_INDEX - char_index || !trie_node_is_free(next)) {
            log_debug("node.base > TRIE_MAX_INDEX\n");
            uint32_t num_transitions = 0;
            // At most alphabet_size - 1 existing transitions here, since the
            // transition on c itself was ruled out above; c fits in the last slot
            unsigned char transitions[self->alphabet_size];
            trie_get_transition_chars(self, node_id, transitions, &num_transitions);

            transitions[num_transitions++] = c;
            new_base = trie_find_new_base(self, transitions, num_transitions);
            // Propagate the failure instead of relocating onto an invalid base
            if (new_base == TRIE_INDEX_ERROR) {
                return TRIE_INDEX_ERROR;
            }

            trie_relocate_base(self, node_id, new_base);
            next_id = new_base + char_index;
        }

    } else {
        unsigned char transitions[] = {c};
        new_base = trie_find_new_base(self, transitions, 1);
        log_debug("Found base for transition char %c, base=%d\n", c, new_base);
        // Propagate the failure and leave the node's base untouched
        if (new_base == TRIE_INDEX_ERROR) {
            return TRIE_INDEX_ERROR;
        }

        trie_set_base(self, node_id, new_base);
        next_id = new_base + char_index;
    }

    // Take the child cell off the free list and attach it to its parent
    log_debug("init_node\n");
    trie_init_node(self, next_id);
    log_debug("setting check\n");
    trie_set_check(self, next_id, node_id);

    return next_id;
}
483 
/*
* Creates a leaf for the key remainder `tail` below from_index: a transition
* on the first byte of tail is added, the rest of the string (after that
* byte, unless it is the terminating NUL) is appended to the tail array, and
* a new data node recording the tail position and `data` is linked from the
* leaf through a negative base.
*
* Returns the index of the leaf, or TRIE_INDEX_ERROR (nothing stored) if the
* transition could not be added.
*/
int32_t trie_separate_tail(trie_t *self, uint32_t from_index, unsigned char *tail, uint32_t data) {
    unsigned char c = *tail;
    uint32_t index = trie_add_transition(self, from_index, c);
    // Do not write through an invalid index
    if (index == TRIE_INDEX_ERROR) {
        return (int32_t)TRIE_INDEX_ERROR;
    }

    if (*tail != '\0') tail++;

    log_debug("Separating node at index %d into char %c with tail %s\n", from_index, c, tail);
    // Leaves point at their data node with a negated data index
    trie_set_base(self, index, -1 * (int32_t)self->data->n);

    trie_data_array_push(self->data, (trie_data_node_t){(uint32_t)self->tail->n, data});
    trie_add_tail(self, tail);

    return (int32_t)index;
}
498 
/*
* Inserts `suffix` below a leaf (old_node_id) that already stores a tail.
*
* If suffix equals the stored tail, the key already exists and only its data
* value is replaced. Otherwise the bytes shared by both strings become a
* chain of regular nodes, the old leaf's data node is re-attached below the
* chain with its tail shortened in place, and the new suffix gets its own
* leaf via trie_separate_tail().
*
* old_node_id is expected to have a negative base (its negation is used as
* the data index). If a transition cannot be added, the nodes created so far
* are pruned and the function returns without reporting the failure.
*/
void trie_tail_merge(trie_t *self, uint32_t old_node_id, unsigned char *suffix, uint32_t data) {
    unsigned char c;
    uint32_t next_id;

    trie_node_t old_node = trie_get_node(self, old_node_id);
    // A leaf's base is the negated index of its data node
    int32_t old_data_index = -1*old_node.base;
    trie_data_node_t old_data_node = self->data->a[old_data_index];
    uint32_t old_tail_pos = old_data_node.tail;

    // Both pointers refer into the tail array itself
    unsigned char *original_tail = self->tail->a + old_tail_pos;
    unsigned char *old_tail = original_tail;
    log_debug("Merging existing tail %s with new tail %s, node_id=%d\n", original_tail, suffix, old_node_id);

    size_t common_prefix = string_common_prefix((char *)old_tail, (char *)suffix);
    size_t old_tail_len = strlen((char *)old_tail);
    size_t suffix_len = strlen((char *)suffix);
    // Identical strings: update the value in place
    if (common_prefix == old_tail_len && old_tail_len == suffix_len) {
        log_debug("Key already exists, setting value to %d\n", data);
        self->data->a[old_data_index] = (trie_data_node_t) {old_tail_pos, data};
        return;
    }

    uint32_t node_id = old_node_id;
    log_debug("common_prefix=%zu\n", common_prefix);

    // Turn the shared bytes into a chain of nodes below the old leaf
    for (size_t i = 0; i < common_prefix; i++) {
        c = old_tail[i];
        log_debug("merge tail, c=%c, node_id=%d\n", c, node_id);
        next_id = trie_add_transition(self, node_id, c);
        if (next_id == TRIE_INDEX_ERROR) {
            goto exit_prune;
        }
        node_id = next_id;
    }

    // Branch for the old key: transition on its first differing byte
    // (the NUL terminator when the old tail is a prefix of suffix)
    uint32_t old_tail_index = trie_add_transition(self, node_id, *(old_tail+common_prefix));
    log_debug("old_tail_index=%d\n", old_tail_index);
    if (old_tail_index == TRIE_INDEX_ERROR) {
        goto exit_prune;
    }

    // Skip the bytes now represented by nodes; the branching byte is skipped
    // too unless it is the terminator
    old_tail += common_prefix;
    if (*old_tail != '\0') {
        old_tail++;
    }

    // Re-attach the old data node to the new leaf and shorten its tail in place
    trie_set_base(self, old_tail_index, -1 * old_data_index);
    trie_set_tail(self, old_tail, old_tail_pos);

    // Branch for the new key
    trie_separate_tail(self, node_id, suffix+common_prefix, data);
    return;

exit_prune:
    // Undo the partially built chain; original_tail aliases the stored tail,
    // so this trie_set_tail() rewrites the bytes with themselves
    trie_prune_up_to(self, old_node_id, node_id);
    trie_set_tail(self, original_tail, old_tail_pos);
    return;
}
556 
557 
558 
/* Number of decimal digits needed to print magnitude. */
static int trie_print_digits(uint32_t magnitude) {
    int digits = 1;
    while (magnitude > 9) {
        magnitude /= 10;
        digits++;
    }
    return digits;
}

/* Printed width of a signed value: its digits plus one column for a minus sign. */
static int trie_print_signed_width(int32_t value) {
    // Negate in unsigned arithmetic so INT32_MIN cannot overflow
    uint32_t magnitude = value < 0 ? 0u - (uint32_t)value : (uint32_t)value;
    return trie_print_digits(magnitude) + (value < 0 ? 1 : 0);
}

/* Column width shared by a node's base and check so the two rows line up. */
static int trie_print_node_width(int32_t base, int32_t check) {
    int base_width = trie_print_signed_width(base);
    int check_width = trie_print_signed_width(check);
    return base_width > check_width ? base_width : check_width;
}

/* Column width shared by a data node's tail position and data value. */
static int trie_print_data_width(uint32_t tail, uint32_t data) {
    int tail_width = trie_print_digits(tail);
    int data_width = trie_print_digits(data);
    return tail_width > data_width ? tail_width : data_width;
}

/*
* Dumps the trie to stdout for debugging: a header line, then one row each
* for all node bases, all node checks, the tail bytes, the data nodes' tail
* positions and the data nodes' values. Columns of paired rows are aligned.
*/
void trie_print(trie_t *self) {
    printf("Trie\n");
    printf("num_nodes=%zu, alphabet_size=%u\n\n", self->nodes->n, (unsigned)self->alphabet_size);

    for (size_t i = 0; i < self->nodes->n; i++) {
        int32_t base = self->nodes->a[i].base;
        int32_t check = self->nodes->a[i].check;
        printf("%*d ", trie_print_node_width(base, check), (int)base);
    }
    printf("\n");

    for (size_t i = 0; i < self->nodes->n; i++) {
        int32_t base = self->nodes->a[i].base;
        int32_t check = self->nodes->a[i].check;
        printf("%*d ", trie_print_node_width(base, check), (int)check);
    }
    printf("\n");

    for (size_t i = 0; i < self->tail->n; i++) {
        printf("%c ", self->tail->a[i]);
    }
    printf("\n");

    // tail and data are unsigned, so print them with %u rather than %d
    for (size_t i = 0; i < self->data->n; i++) {
        uint32_t tail = self->data->a[i].tail;
        uint32_t data = self->data->a[i].data;
        printf("%*u ", trie_print_data_width(tail, data), (unsigned)tail);
    }
    printf("\n");

    for (size_t i = 0; i < self->data->n; i++) {
        uint32_t tail = self->data->a[i].tail;
        uint32_t data = self->data->a[i].data;
        printf("%*u ", trie_print_data_width(tail, data), (unsigned)data);
    }
    printf("\n");

}
617 
/*
* Inserts the first len bytes of key below node_id and associates them with
* data.
*
* The existing nodes are followed byte by byte. At the first byte that has
* no transition, the remainder of the key is stored as a new leaf with a
* tail; if instead a leaf (negative base) is reached, the remainder is merged
* with that leaf's stored tail.
*
* Returns false if key is just the prefix/suffix marker character followed by
* NUL (with len == 2), or if the starting node's base equals NULL_NODE_ID;
* true otherwise.
*/
bool trie_add_at_index(trie_t *self, uint32_t node_id, char *key, size_t len, uint32_t data) {
    // Refuse a key consisting only of the reserved marker character
    if (len == 2 && (key[0] == TRIE_SUFFIX_CHAR[0] || key[0] == TRIE_PREFIX_CHAR[0]) && key[1] == '\0') {
        return false;
    }

    unsigned char *ptr = (unsigned char *)key;
    uint32_t last_node_id = node_id;
    trie_node_t last_node = trie_get_node(self, node_id);
    if (last_node.base == NULL_NODE_ID) {
        log_debug("last_node.base == NULL_NODE_ID, node_id = %d\n", node_id);
        return false;
    }

    trie_node_t node;

    // Walks node until prefix reached, including the trailing \0

    // The loop's update clause advances the parent (last_node_id/last_node)
    // to the child that was just visited
    for (size_t i = 0; i < len; ptr++, i++, last_node_id = node_id, last_node = node) {

        log_debug("--- char=%d\n", *ptr);
        node_id = trie_get_transition_index(self, last_node, *ptr);
        log_debug("node_id=%d, last_node.base=%d, last_node.check=%d, char_index=%d\n", node_id, last_node.base, last_node.check, trie_get_char_index(self, *ptr));

        if (node_id != NULL_NODE_ID) {
            trie_make_room_for(self, node_id);
        }

        node = trie_get_node(self, node_id);
        log_debug("node.check=%d, last_node_id=%d, node.base=%d\n", node.check, last_node_id, node.base);

        if (node.check < 0 || (node.check != last_node_id)) {
            // The cell is free or owned by another parent: no transition on
            // this byte yet, so store the rest of the key as a new leaf
            log_debug("last_node_id=%d, ptr=%s, tail_pos=%zu\n", last_node_id,  ptr, self->tail->n);
            trie_separate_tail(self, last_node_id, ptr, data);
            break;
        } else if (node.base < 0 && node.check == last_node_id) {
            // Reached an existing leaf: merge the rest of the key with its tail
            log_debug("Case 3 insertion\n");
            trie_tail_merge(self, node_id, ptr + 1, data);
            break;
        }
    }

    // NOTE(review): incremented on every call that gets this far, including
    // when trie_tail_merge() only replaced the value of an existing key or
    // when the loop ends without inserting anything
    self->num_keys++;
    return true;
}
662 
663 
/*
* Inserts the NUL-terminated string key at the root with the given data.
* Returns false for an empty key; otherwise the result of trie_add_at_index().
*/
inline bool trie_add(trie_t *self, char *key, uint32_t data) {
    size_t key_len = strlen(key);
    if (key_len > 0) {
        // + 1 so the terminating NUL is inserted as part of the key
        return trie_add_at_index(self, ROOT_NODE_ID, key, key_len + 1, data);
    }
    return false;
}
669 
/* Inserts the first len bytes of key at the root with the given data. */
inline bool trie_add_len(trie_t *self, char *key, size_t len, uint32_t data) {
    bool added = trie_add_at_index(self, ROOT_NODE_ID, key, len, data);
    return added;
}
673 
/*
* Inserts key (without its terminating NUL) below the prefix-marker child of
* start_node_id, creating that child first if it does not exist.
*
* Returns false when start_node_id is NULL_NODE_ID or key is empty; otherwise
* the result of trie_add_at_index().
*/
bool trie_add_prefix_at_index(trie_t *self, char *key, uint32_t start_node_id, uint32_t data) {
    size_t len = strlen(key);
    if (len == 0 || start_node_id == NULL_NODE_ID) {
        return false;
    }

    unsigned char marker = TRIE_PREFIX_CHAR[0];
    trie_node_t start = trie_get_node(self, start_node_id);

    // Reuse the marker transition if present, otherwise add it
    uint32_t marker_id = trie_get_transition_index(self, start, marker);
    if (trie_get_node(self, marker_id).check != start_node_id) {
        marker_id = trie_add_transition(self, start_node_id, marker);
    }

    return trie_add_at_index(self, marker_id, key, len, data);
}
692 
/* Inserts key as a prefix entry below the root; see trie_add_prefix_at_index(). */
inline bool trie_add_prefix(trie_t *self, char *key, uint32_t data) {
    uint32_t start_node_id = ROOT_NODE_ID;
    return trie_add_prefix_at_index(self, key, start_node_id, data);
}
696 
/*
* Inserts the reversed form of key (without a terminating NUL) below the
* suffix-marker child of start_node_id, creating that child first if it does
* not exist.
*
* Returns false when start_node_id is NULL_NODE_ID, key is empty, or the
* reversed copy of key could not be created; otherwise the result of
* trie_add_at_index().
*/
bool trie_add_suffix_at_index(trie_t *self, char *key, uint32_t start_node_id, uint32_t data) {
    size_t len = strlen(key);
    if (start_node_id == NULL_NODE_ID || len == 0) return false;

    // Reverse first so a failure here leaves the trie untouched. The result
    // is heap-allocated (it is freed below); presumably NULL signals failure.
    char *suffix = utf8_reversed_string(key);
    if (suffix == NULL) return false;

    trie_node_t start_node = trie_get_node(self, start_node_id);

    unsigned char suffix_char = TRIE_SUFFIX_CHAR[0];

    // Reuse the marker transition if present, otherwise add it
    uint32_t node_id = trie_get_transition_index(self, start_node, suffix_char);
    trie_node_t node = trie_get_node(self, node_id);
    if (node.check != start_node_id) {
        node_id = trie_add_transition(self, start_node_id, suffix_char);
    }

    bool success = trie_add_at_index(self, node_id, suffix, len, data);

    free(suffix);
    return success;

}
719 
/* Inserts key as a suffix entry below the root; see trie_add_suffix_at_index(). */
inline bool trie_add_suffix(trie_t *self, char *key, uint32_t data) {
    uint32_t start_node_id = ROOT_NODE_ID;
    return trie_add_suffix_at_index(self, key, start_node_id, data);
}
723 
/*
* Returns true if the first len bytes of str match the stored tail starting
* at tail_index (strncmp semantics). Returns false when tail_index is past
* the end of the tail array.
*/
bool trie_compare_tail(trie_t *self, char *str, size_t len, size_t tail_index) {
    if (tail_index < self->tail->n) {
        const char *stored = (const char *)(self->tail->a + tail_index);
        return strncmp(stored, str, len) == 0;
    }
    return false;
}
730 
trie_get_data_node(trie_t * self,trie_node_t node)731 inline trie_data_node_t trie_get_data_node(trie_t *self, trie_node_t node) {
732     if (node.base >= 0) {
733         return NULL_DATA_NODE;
734     }
735     int32_t data_index = -1*node.base;
736     trie_data_node_t data_node = self->data->a[data_index];
737     return data_node;
738 }
739 
/*
* Stores data_node at position index of the data array.
* Returns false (and writes nothing) if self or its data array is NULL or
* index is out of range.
*/
inline bool trie_set_data_node(trie_t *self, uint32_t index, trie_data_node_t data_node) {
    bool writable = self != NULL && self->data != NULL && index < self->data->n;
    if (writable) {
        self->data->a[index] = data_node;
    }
    return writable;
}
745 
/*
* Reads the data value attached to the node at index into *data.
* Returns false, leaving *data untouched, when index is NULL_NODE_ID or the
* node's data node has a tail position of 0.
*/
inline bool trie_get_data_at_index(trie_t *self, uint32_t index,  uint32_t *data) {
    if (index != NULL_NODE_ID) {
        trie_data_node_t found = trie_get_data_node(self, trie_get_node(self, index));
        if (found.tail != 0) {
            *data = found.data;
            return true;
        }
    }
    return false;
}
756 
/* Looks key up with trie_get() and reads its data value into *data; false if none is found. */
inline bool trie_get_data(trie_t *self, char *key, uint32_t *data) {
    return trie_get_data_at_index(self, trie_get(self, key), data);
}
761 
/*
* Replaces the data value attached to the node at index, keeping the data
* node's tail position.
*
* Returns false when index is NULL_NODE_ID, when the node is a free cell,
* when it does not reference a data node (non-negative base), or when the
* referenced data index is out of range.
*/
inline bool trie_set_data_at_index(trie_t *self, uint32_t index, uint32_t data) {
    if (index == NULL_NODE_ID) return false;

    trie_node_t node = trie_get_node(self, index);
    // Only a used node with a negative base points at a data node. Without
    // this guard a base of 0 would overwrite the padding entry at data
    // index 0 and report success.
    if (trie_node_is_free(node) || node.base >= 0) return false;

    trie_data_node_t data_node = trie_get_data_node(self, node);
    data_node.data = data;
    return trie_set_data_node(self, -1*node.base, data_node);

}
770 
/*
* Sets the data value for key: updates it when trie_get() finds the key,
* otherwise inserts the key with trie_add().
*/
inline bool trie_set_data(trie_t *self, char *key, uint32_t data) {
    uint32_t existing = trie_get(self, key);
    return (existing == NULL_NODE_ID)
        ? trie_add(self, key, data)
        : trie_set_data_at_index(self, existing, data);
}
779 
/*
* Matches the first len bytes of key against the trie, starting at node
* start_index and, if that node is a leaf, at offset tail_pos within its
* stored tail.
*
* Returns {node_id, tail_pos}: tail_pos is 0 when the match ended on a
* regular node, otherwise the offset within the leaf's tail just past the
* matched bytes. Returns NULL_PREFIX_RESULT when key is NULL, the start
* node's base equals NULL_NODE_ID, a transition is missing, or the tail does
* not match.
*/
trie_prefix_result_t trie_get_prefix_from_index(trie_t *self, char *key, size_t len, uint32_t start_index, size_t tail_pos) {
    if (key == NULL) {
        return NULL_PREFIX_RESULT;
    }

    unsigned char *ptr = (unsigned char *)key;

    uint32_t node_id = start_index;
    trie_node_t node = trie_get_node(self, node_id);
    if (node.base == NULL_NODE_ID) {
        return NULL_PREFIX_RESULT;
    }

    uint32_t next_id = NULL_NODE_ID;

    // Remember whether we started on a regular node: only then does the loop
    // below consume a byte of key when it stops on a leaf
    bool original_node_no_tail = node.base >= 0;

    size_t i = 0;

    if (node.base >= 0) {
        // Include NUL-byte. It may be stored if this phrase is a prefix of a longer one
        for (i = 0; i < len; i++, ptr++, node_id = next_id) {
            next_id = trie_get_transition_index(self, node, *ptr);
            node = trie_get_node(self, next_id);

            // The cell is not a child of the current node: no such transition
            if (node.check != node_id) {
                return NULL_PREFIX_RESULT;
            }

            // Reached a leaf; i and ptr still refer to the byte that led here
            if (node.base < 0) break;
        }
    } else {
        // Started on a leaf: continue matching inside its tail
        next_id = node_id;
        node = trie_get_node(self, node_id);
    }

    if (node.base < 0) {
        trie_data_node_t data_node = trie_get_data_node(self, node);

        // Skip the byte that was consumed by the transition into the leaf
        char *query_tail = (*ptr && original_node_no_tail) ? (char *)ptr + 1 : (char *)ptr;
        size_t query_len = (*ptr && original_node_no_tail) ? len - i - 1 : len - i;

        if (data_node.tail != 0 && trie_compare_tail(self, query_tail, query_len, data_node.tail + tail_pos)) {
            return (trie_prefix_result_t){next_id, tail_pos + query_len};
        } else {
            return NULL_PREFIX_RESULT;

        }
    } else {
        // All len bytes were matched by regular nodes
        return (trie_prefix_result_t){next_id, 0};
    }

    // Not reached: both branches above return
    return NULL_PREFIX_RESULT;

}
835 
trie_get_prefix_len(trie_t * self,char * key,size_t len)836 trie_prefix_result_t trie_get_prefix_len(trie_t *self, char *key, size_t len) {
837     return trie_get_prefix_from_index(self, key, len, ROOT_NODE_ID, 0);
838 }
839 
trie_get_prefix(trie_t * self,char * key)840 trie_prefix_result_t trie_get_prefix(trie_t *self, char *key) {
841     return trie_get_prefix_from_index(self, key, strlen(key), ROOT_NODE_ID, 0);
842 }
843 
/*
 * Exact-match lookup of word starting from node i.
 *
 * Follows one transition per byte for len + 1 bytes, i.e. including the
 * byte at word[len]; a terminating NUL may be stored as its own transition
 * when the word is a prefix of a longer key. word is therefore expected to
 * be NUL-terminated: besides reading word[len], the tail comparison below
 * measures the remainder with strlen.
 *
 * Returns the id of the matching node, or NULL_NODE_ID when word is NULL,
 * the start node's base is NULL_NODE_ID, a transition does not belong to
 * the current node, or the remainder does not match the stored tail.
 */
uint32_t trie_get_from_index(trie_t *self, char *word, size_t len, uint32_t i) {
    if (word == NULL) {
        return NULL_NODE_ID;
    }

    trie_node_t node = trie_get_node(self, i);
    if (node.base == NULL_NODE_ID) {
        return NULL_NODE_ID;
    }

    unsigned char *cursor = (unsigned char *)word;
    uint32_t parent_id = i;
    uint32_t child_id;

    for (size_t pos = 0; pos < len + 1; pos++) {
        child_id = trie_get_transition_index(self, node, *cursor);
        node = trie_get_node(self, child_id);

        if (node.check != parent_id) {
            return NULL_NODE_ID;
        }

        if (node.base < 0) {
            // Negative base: the rest of the word must equal the stored tail
            trie_data_node_t data_node = trie_get_data_node(self, node);

            // A non-NUL byte was consumed by the transition itself
            char *query_tail = (char *)cursor;
            if (*cursor) {
                query_tail++;
            }

            if (data_node.tail == 0) {
                return NULL_NODE_ID;
            }

            // The + 1 includes the terminating NUL in the comparison length
            return trie_compare_tail(self, query_tail, strlen(query_tail) + 1, data_node.tail)
                ? child_id
                : NULL_NODE_ID;
        }

        cursor++;
        parent_id = child_id;
    }

    return child_id;
}
883 
/*
 * Looks up word (with explicit length len) starting at the root node and
 * returns the node id reported by trie_get_from_index.
 */
uint32_t trie_get_len(trie_t *self, char *word, size_t len) {
    uint32_t found_id = trie_get_from_index(self, word, len, ROOT_NODE_ID);
    return found_id;
}
887 
/*
 * Looks up a NUL-terminated word starting at the root node.
 *
 * Returns NULL_NODE_ID when word is NULL; otherwise returns the node id
 * reported by trie_get_from_index over strlen(word) bytes.
 */
uint32_t trie_get(trie_t *self, char *word) {
    // strlen(NULL) is undefined behavior; the callee's own NULL check would
    // run too late to prevent it.
    if (word == NULL) {
        return NULL_NODE_ID;
    }

    size_t word_len = strlen(word);
    return trie_get_from_index(self, word, word_len, ROOT_NODE_ID);
}
892 
893 
/*
 * Returns the num_keys counter of the trie, or 0 for a NULL trie.
 */
inline uint32_t trie_num_keys(trie_t *self) {
    return self != NULL ? self->num_keys : 0;
}
898 
899 /*
900 Destructor
901 */
/*
 * Releases the alphabet, node array, tail array and data array owned by
 * the trie (each only if present), then the trie itself.
 * A NULL trie is ignored.
 */
void trie_destroy(trie_t *self) {
    if (self == NULL) {
        return;
    }

    // free() is a no-op on NULL, so no guard is needed here
    free(self->alphabet);

    if (self->nodes != NULL) {
        trie_node_array_destroy(self->nodes);
    }

    if (self->tail != NULL) {
        uchar_array_destroy(self->tail);
    }

    if (self->data != NULL) {
        trie_data_array_destroy(self->data);
    }

    free(self);
}
917 
918 
919 /*
920 I/O methods
921 */
922 
/*
 * Serializes the trie to file in this order:
 *   signature, alphabet size, alphabet bytes, number of keys,
 *   node count followed by (base, check) per node,
 *   data-node count followed by (tail, data) per data node,
 *   tail length followed by the raw tail bytes.
 *
 * Returns false as soon as any write fails, true once everything is written.
 */
bool trie_write(trie_t *self, FILE *file) {
    /* Header */
    if (!file_write_uint32(file, TRIE_SIGNATURE)) {
        return false;
    }
    if (!file_write_uint32(file, self->alphabet_size)) {
        return false;
    }
    if (!file_write_chars(file, (char *)self->alphabet, (size_t)self->alphabet_size)) {
        return false;
    }
    if (!file_write_uint32(file, self->num_keys)) {
        return false;
    }

    /* Nodes: count, then base/check pairs */
    if (!file_write_uint32(file, (uint32_t)self->nodes->n)) {
        return false;
    }

    for (size_t node_idx = 0; node_idx < self->nodes->n; node_idx++) {
        trie_node_t current = self->nodes->a[node_idx];

        if (!file_write_uint32(file, (uint32_t)current.base)) {
            return false;
        }
        if (!file_write_uint32(file, (uint32_t)current.check)) {
            return false;
        }
    }

    /* Data nodes: count, then tail/data pairs */
    if (!file_write_uint32(file, (uint32_t)self->data->n)) {
        return false;
    }

    for (size_t data_idx = 0; data_idx < self->data->n; data_idx++) {
        trie_data_node_t entry = self->data->a[data_idx];

        if (!file_write_uint32(file, entry.tail)) {
            return false;
        }
        if (!file_write_uint32(file, entry.data)) {
            return false;
        }
    }

    /* Tail: length, then raw bytes */
    if (!file_write_uint32(file, (uint32_t)self->tail->n)) {
        return false;
    }

    return file_write_chars(file, (char *)self->tail->a, self->tail->n) ? true : false;
}
963 
964 
/*
 * Writes the trie to the file at path, creating or truncating it.
 *
 * Returns true only when the trie was fully serialized and the file was
 * closed successfully; returns false for a NULL trie or path, when the file
 * cannot be opened, when trie_write fails, or when fclose reports an error.
 */
bool trie_save(trie_t *self, char *path) {
    if (self == NULL || path == NULL) {
        return false;
    }

    // The format is raw binary, so open in binary mode; a text-mode stream
    // may translate newline bytes on some platforms and corrupt the data.
    // This also matches the "rb" mode used when loading.
    FILE *file = fopen(path, "wb");
    if (file == NULL) {
        return false;
    }

    bool written = trie_write(self, file);

    // fclose flushes buffered output; a failure here means the file on disk
    // may be incomplete even though every write call succeeded.
    if (fclose(file) != 0) {
        return false;
    }

    return written;
}
978 
trie_read(FILE * file)979 trie_t *trie_read(FILE *file) {
980     uint32_t i;
981 
982     long save_pos = ftell(file);
983 
984     uint8_t alphabet[NUM_CHARS];
985 
986     uint32_t signature;
987 
988     if (!file_read_uint32(file, &signature)) {
989         goto exit_file_read;
990     }
991 
992     if (signature != TRIE_SIGNATURE) {
993         goto exit_file_read;
994     }
995 
996     uint32_t alphabet_size;
997 
998     if (!file_read_uint32(file, &alphabet_size)) {
999         goto exit_file_read;
1000     }
1001 
1002     log_debug("alphabet_size=%d\n", alphabet_size);
1003     if (alphabet_size > NUM_CHARS)
1004         goto exit_file_read;
1005 
1006     if (!file_read_chars(file, (char *)alphabet, alphabet_size)) {
1007         goto exit_file_read;
1008     }
1009 
1010     trie_t *trie = trie_new_empty(alphabet, alphabet_size);
1011     if (!trie) {
1012         goto exit_file_read;
1013     }
1014 
1015     uint32_t num_keys;
1016     if (!file_read_uint32(file, &num_keys)) {
1017         goto exit_trie_created;
1018     }
1019 
1020     trie->num_keys = num_keys;
1021 
1022     uint32_t num_nodes;
1023 
1024     if (!file_read_uint32(file, &num_nodes)) {
1025         goto exit_trie_created;
1026     }
1027 
1028     log_debug("num_nodes=%d\n", num_nodes);
1029     trie_node_array_resize(trie->nodes, num_nodes);
1030 
1031     int32_t base;
1032     int32_t check;
1033     trie_node_t node;
1034 
1035     unsigned char *buf;
1036     size_t buf_size = num_nodes * sizeof(uint32_t) * 2;
1037     buf = malloc(buf_size);
1038     if (buf == NULL) {
1039         goto exit_trie_created;
1040     }
1041 
1042     unsigned char *buf_ptr;
1043 
1044     if (file_read_chars(file, (char *)buf, buf_size)) {
1045         buf_ptr = buf;
1046         for (i = 0; i < num_nodes; i++) {
1047             node.base = (int32_t)file_deserialize_uint32(buf_ptr);
1048             buf_ptr += sizeof(uint32_t);
1049             node.check = (int32_t)file_deserialize_uint32(buf_ptr);
1050             buf_ptr += sizeof(uint32_t);
1051 
1052             trie_node_array_push(trie->nodes, node);
1053         }
1054     }
1055 
1056     free(buf);
1057     buf = NULL;
1058 
1059     uint32_t num_data_nodes;
1060     if (!file_read_uint32(file, &num_data_nodes)) {
1061         goto exit_trie_created;
1062     }
1063 
1064     trie_data_array_resize(trie->data, num_data_nodes);
1065     log_debug("num_data_nodes=%d\n", num_data_nodes);
1066 
1067     trie_data_node_t data_node;
1068 
1069     buf_size = num_data_nodes * sizeof(uint32_t) * 2;
1070     buf = malloc(buf_size);
1071     if (buf == NULL) {
1072         goto exit_trie_created;
1073     }
1074 
1075     if (file_read_chars(file, (char *)buf, buf_size)) {
1076         buf_ptr = buf;
1077         for (i = 0; i < num_data_nodes; i++) {
1078             data_node.tail = (int32_t)file_deserialize_uint32(buf_ptr);
1079             buf_ptr += sizeof(uint32_t);
1080             data_node.data = (int32_t)file_deserialize_uint32(buf_ptr);
1081             buf_ptr += sizeof(uint32_t);
1082 
1083             trie_data_array_push(trie->data, data_node);
1084         }
1085     }
1086 
1087     free(buf);
1088 
1089     uint32_t tail_len;
1090     if (!file_read_uint32(file, &tail_len)) {
1091         goto exit_trie_created;
1092     }
1093 
1094     uchar_array_resize(trie->tail, tail_len);
1095     trie->tail->n = tail_len;
1096 
1097     if (!file_read_chars(file, (char *)trie->tail->a, tail_len)) {
1098         goto exit_trie_created;
1099     }
1100 
1101     return trie;
1102 
1103 exit_trie_created:
1104     trie_destroy(trie);
1105 exit_file_read:
1106     fseek(file, save_pos, SEEK_SET);
1107     return NULL;
1108 }
1109 
trie_load(char * path)1110 trie_t *trie_load(char *path) {
1111     FILE *file;
1112 
1113     file = fopen(path, "rb");
1114     if (!file)
1115         return NULL;
1116 
1117     trie_t *trie = trie_read(file);
1118 
1119     fclose(file);
1120 
1121     return trie;
1122 }
1123