1 #include "trie_search.h"
2 
/*
 * Where the trie walk stands relative to the current candidate phrase.
 * The explicit values are the same 0..3 the compiler would assign implicitly.
 */
typedef enum {
    SEARCH_STATE_BEGIN = 0,         /* initial state; also used after a non-letter falls off the trie */
    SEARCH_STATE_NO_MATCH = 1,      /* a letter fell off the trie */
    SEARCH_STATE_PARTIAL_MATCH = 2, /* a valid transition was taken: prefix of some key consumed */
    SEARCH_STATE_MATCH = 3          /* a complete key was consumed */
} trie_search_state_t;
9 
/*
 * Scan the NUL-terminated string `text` through the double-array trie,
 * starting at node `start_node_id`, and append every phrase found to
 * *phrases.
 *
 * Each pushed phrase_t is {start byte offset into text, byte length, data}.
 * *phrases is allocated lazily with phrase_array_new_size(1) on the first
 * push, so it is left untouched (e.g. still NULL) if nothing matches.
 *
 * Input is decoded one UTF-8 character at a time (utf8proc_iterate), and the
 * bytes of each character are fed to the trie one by one. A transition is
 * valid when the child's `check` equals the parent node id. A child with
 * base < 0 stores the rest of its key as a tail string in self->tail. A
 * '\0' transition out of a node marks the end of a key.
 *
 * Word-boundary handling: once a letter falls off the trie, the remaining
 * letters of that word are skipped. A pending match that is immediately
 * followed by a letter is discarded rather than reported.
 *
 * Returns false if text is NULL or if a decoding error / invalid codepoint
 * is encountered; phrases already pushed stay in *phrases. Returns true
 * after the terminating NUL has been processed.
 */
bool trie_search_from_index(trie_t *self, char *text, uint32_t start_node_id, phrase_array **phrases) {
    if (text == NULL) return false;

    ssize_t len, remaining;
    int32_t unich = 0;
    unsigned char ch = '\0';

    const uint8_t *ptr = (const uint8_t *)text;
    // Where scanning resumes if a longer candidate phrase fails (see the tail-mismatch branch)
    const uint8_t *fail_ptr = ptr;

    uint32_t node_id = start_node_id;
    trie_node_t node = trie_get_node(self, node_id), last_node = node;
    uint32_t next_id;

    // true while a complete phrase [phrase_start, phrase_start + phrase_len) is waiting to be pushed
    bool match = false;
    // Byte offset of the current character within text
    uint32_t index = 0;
    uint32_t phrase_len = 0;
    uint32_t phrase_start = 0;
    // Data value of the pending match
    uint32_t data;

    trie_search_state_t state = SEARCH_STATE_BEGIN, last_state = SEARCH_STATE_BEGIN;

    // Cleared when index has already been repositioned during this iteration
    bool advance_index = true;

    while(1) {
        // Decode the next character (-1: no explicit length bound)
        len = utf8proc_iterate(ptr, -1, &unich);
        remaining = len;
        if (len <= 0) return false;
        if (!(utf8proc_codepoint_valid(unich))) return false;

        int cat = utf8proc_category(unich);
        bool is_letter = utf8_is_letter(cat);

        // If we're in the middle of a word and the first letter was not a match, skip the word
        if (is_letter && state == SEARCH_STATE_NO_MATCH) {
            log_debug("skipping\n");
            ptr += len;
            index += len;
            last_state = state;
            continue;
        }

        // Match in the middle of a word
        if (is_letter && last_state == SEARCH_STATE_MATCH) {
            log_debug("last_state == SEARCH_STATE_MATCH && is_letter\n");
            // Only set match to false so we don't callback
            match = false;
        }

        // Feed each byte of the current character to the trie. The increment
        // step also records last_node/last_state and moves node_id to next_id.
        for (int i=0; remaining > 0; remaining--, i++, ptr++, last_node=node, last_state=state, node_id=next_id) {
            ch = (unsigned char) *ptr;
            log_debug("char=%c\n", ch);

            next_id = trie_get_transition_index(self, node, *ptr);
            node = trie_get_node(self, next_id);

            if (node.check != node_id) {
                // No valid transition: we fell off the trie
                state = is_letter ? SEARCH_STATE_NO_MATCH : SEARCH_STATE_BEGIN;
                if (match) {
                    // Emit the pending phrase, then rewind to the byte just past it
                    // so the remaining text is rescanned from the start node
                    log_debug("match is true and state==SEARCH_STATE_NO_MATCH\n");
                    if (*phrases == NULL) {
                        *phrases = phrase_array_new_size(1);
                    }
                    phrase_array_push(*phrases, (phrase_t){phrase_start, phrase_len, data});
                    index = phrase_start + phrase_len;
                    advance_index = false;
                    // Set the text back to the end of the last phrase
                    ptr = (const uint8_t *)text + index;
                    len = utf8proc_iterate(ptr, -1, &unich);
                    log_debug("ptr=%s\n", ptr);
                } else {
                    // Skip the remaining bytes of this character
                    ptr += remaining;
                    log_debug("done with char, now at %s\n", ptr);
                }
                fail_ptr = ptr;
                node_id = start_node_id;
                last_node = node = trie_get_node(self, node_id);
                phrase_start = phrase_len = 0;
                last_state = state;
                match = false;
                break;
            } else {
                log_debug("node.check == node_id\n");
                state = SEARCH_STATE_PARTIAL_MATCH;
                // First valid transition of a new candidate phrase: remember where it starts
                if (last_state == SEARCH_STATE_NO_MATCH || last_state == SEARCH_STATE_BEGIN) {
                    log_debug("phrase_start=%u\n", index);
                    phrase_start = index;
                    fail_ptr = ptr + remaining;
                }

                // Leaf node: the rest of the key is stored as a tail string; compare it with the input
                if (node.base < 0) {
                    int32_t data_index = -1*node.base;
                    trie_data_node_t data_node = self->data->a[data_index];
                    unsigned char *current_tail = self->tail->a + data_node.tail;

                    size_t tail_len = strlen((char *)current_tail);
                    // Input bytes after the current one (or the NUL itself if we are on it)
                    char *query_tail = (char *)(*ptr ? ptr + 1 : ptr);
                    size_t query_tail_len = strlen((char *)query_tail);
                    log_debug("next node tail: %s\n", current_tail);
                    log_debug("query node tail: %s\n", query_tail);

                    if (tail_len <= query_tail_len && strncmp((char *)current_tail, query_tail, tail_len) == 0) {
                        // Whole tail present in the input: consume it and record the match
                        state = SEARCH_STATE_MATCH;
                        log_debug("Tail matches\n");
                        last_state = state;
                        data = data_node.data;
                        log_debug("%u, %d, %zu\n", index, phrase_len, tail_len);
                        ptr += tail_len;
                        index += tail_len;
                        advance_index = false;
                        phrase_len = index + 1 - phrase_start;
                        match = true;
                    } else if (match) {
                        // A longer phrase was being tried and its tail failed:
                        // emit the earlier match and rewind to fail_ptr
                        log_debug("match is true and longer phrase tail did not match\n");
                        log_debug("phrase_start=%d, phrase_len=%d\n", phrase_start, phrase_len);
                        if (*phrases == NULL) {
                            *phrases = phrase_array_new_size(1);
                        }
                        phrase_array_push(*phrases, (phrase_t){phrase_start, phrase_len, data});
                        ptr = fail_ptr;
                        match = false;
                        index = phrase_start + phrase_len;
                        advance_index = false;
                    }

                }

                // Does the current node also end a key? (transition on '\0')
                if (ch != '\0') {
                    trie_node_t terminal_node = trie_get_transition(self, node, '\0');
                    if (terminal_node.check == next_id) {
                        log_debug("Transition to NUL byte matched\n");
                        state = SEARCH_STATE_MATCH;
                        match = true;
                        phrase_len = index + (uint32_t)len - phrase_start;
                        if (terminal_node.base < 0) {
                            int32_t data_index = -1*terminal_node.base;
                            trie_data_node_t data_node = self->data->a[data_index];
                            data = data_node.data;
                        }
                        log_debug("Got match with len=%d\n", phrase_len);
                        fail_ptr = ptr;
                    }
                }
            }

        }

        // Reached the terminating NUL: flush a match that runs to the end of the string
        if (unich == 0) {
            if (last_state == SEARCH_STATE_MATCH) {
                log_debug("Found match at the end\n");
                if (*phrases == NULL) {
                    *phrases = phrase_array_new_size(1);
                }
                phrase_array_push(*phrases, (phrase_t){phrase_start, phrase_len, data});
            }
            break;
        }

        // Skip the advance when index was already repositioned above
        if (advance_index) index += len;

        advance_index = true;
        log_debug("index now %u\n", index);
    } // while

    return true;
}
176 
/*
 * Search `str` starting from the trie root, appending matches to *phrases.
 * Thin wrapper over trie_search_from_index with ROOT_NODE_ID.
 */
inline bool trie_search_with_phrases(trie_t *self, char *str, phrase_array **phrases) {
    bool ok = trie_search_from_index(self, str, ROOT_NODE_ID, phrases);
    return ok;
}
180 
trie_search(trie_t * self,char * text)181 inline phrase_array *trie_search(trie_t *self, char *text) {
182     phrase_array *phrases = NULL;
183     if (!trie_search_with_phrases(self, text, &phrases)) {
184         return false;
185     }
186     return phrases;
187 }
188 
/*
 * Continue matching the tail string of leaf `node` across the tokens that
 * follow. Matching starts `tail_index` bytes into the tail and at token
 * `token_index`.
 *
 * A ' ' in the tail absorbs a WHITESPACE token; before any other token the
 * ' ' is stepped over. Every token not absorbed this way must equal the next
 * token.len bytes of the tail.
 *
 * Returns the index of the last token consumed once the whole tail has been
 * matched (token_index - 1 if the tail was already exhausted). Returns -1 on
 * a mismatch or if the tokens run out before the tail does.
 */
int trie_node_search_tail_tokens(trie_t *self, trie_node_t node, char *str, token_array *tokens, size_t tail_index, int token_index) {
    int32_t data_index = -1 * node.base;
    uint32_t tail_start = self->data->a[data_index].tail;

    log_debug("tail_index = %zu\n", tail_index);

    unsigned char *tail = self->tail->a + tail_start + tail_index;

    if (*tail == '\0') {
        log_debug("tail matches!\n");
        return token_index - 1;
    }

    log_debug("Searching tail: %s\n", tail);

    size_t count = tokens->n;
    for (int t = token_index; t < count; t++) {
        token_t tok = tokens->a[t];
        char *text = str + tok.offset;
        size_t text_len = tok.len;

        // Tail exhausted by the previous token: it ended there
        if (*tail == '\0') {
            log_debug("tail matches!\n");
            return t - 1;
        }

        if (*tail == ' ') {
            // A whitespace token lines up with the pending space; leave the space for the next token
            if (tok.type == WHITESPACE) continue;
            tail++;
            log_debug("Got space, advancing pointer, tail_ptr=%s\n", tail);
        }

        log_debug("Tail string compare: %s with %.*s\n", tail, (int)text_len, text);

        if (strncmp((char *)tail, text, text_len) != 0) {
            return -1;
        }

        tail += text_len;
        if (t == count - 1 && *tail == '\0') {
            return t;
        }
    }
    return -1;

}
238 
239 
/*
 * Token-level trie search: walk `tokens` (byte spans into `str`) starting
 * from node `start_node_id` and append each matched phrase to *phrases.
 *
 * Here each phrase_t is {first token index, number of tokens, data}; the
 * offsets are token indices, not byte offsets. The bytes of each
 * non-WHITESPACE token are fed to the trie. Between tokens the walk tries a
 * ' ' transition (the "continuation"). WHITESPACE tokens are never fed to
 * the trie.
 *
 * A complete key seen so far is remembered in last_match_index and pushed
 * only once the walk can no longer continue (fall-off, no continuation, last
 * token). If a partial phrase falls off without any complete match, scanning
 * restarts just after phrase_start.
 *
 * *phrases is allocated lazily with phrase_array_new_size(1) on the first push.
 *
 * Returns false if str or tokens is NULL or there are no tokens; true otherwise.
 */
bool trie_search_tokens_from_index(trie_t *self, char *str, token_array *tokens, uint32_t start_node_id, phrase_array **phrases) {
    if (str == NULL || tokens == NULL || tokens->n == 0) return false;

    // node_id: current node; last_node_id: its parent, used for the `check` validity test
    uint32_t node_id = start_node_id, last_node_id = start_node_id;
    trie_node_t node = trie_get_node(self, node_id), last_node = node;

    uint32_t data;

    // phrase_start: first token of the candidate phrase
    // last_match_index: last token of the most recent complete match, -1 if none pending
    int phrase_len = 0, phrase_start = 0, last_match_index = -1;

    trie_search_state_t state = SEARCH_STATE_BEGIN, last_state = SEARCH_STATE_BEGIN;

    token_t token;
    size_t token_length;

    log_debug("num_tokens: %zu\n", tokens->n);
    // NOTE: i is reassigned inside the body to rewind/skip tokens; last_state is updated in the increment
    for (int i = 0; i < tokens->n; i++, last_state = state) {
        token = tokens->a[i];
        token_length = token.len;

        char *ptr = str + token.offset;
        log_debug("On %d, token=%.*s\n", i, (int)token_length, ptr);

        bool check_continuation = true;

        if (token.type != WHITESPACE) {
            // Feed the token's bytes to the trie one at a time
            for (int j = 0; j < token_length; j++, ptr++, last_node = node, last_node_id = node_id) {
                log_debug("Getting transition index for %d, (%d, %d)\n", node_id, node.base, node.check);
                size_t offset = j + 1;
                if (j > 0 || last_node.base >= 0) {
                    node_id = trie_get_transition_index(self, node, *ptr);
                    node = trie_get_node(self, node_id);
                    log_debug("Doing %c, got node_id=%d\n", *ptr, node_id);
                } else {
                    // The ' ' continuation node is itself a tail leaf: take no transition and
                    // back ptr up one byte so the tail comparison (at ptr + 1) covers the whole token
                    log_debug("Tail stored on space node, rolling back one character\n");
                    ptr--;
                    offset = j;
                    log_debug("ptr=%s\n", ptr);
                }

                if (node.check != last_node_id && last_node.base >= 0) {
                    // Fell off the trie: reset to the start node and stop feeding this token
                    log_debug("Fell off trie. last_node_id=%d and node.check=%d\n", last_node_id, node.check);
                    node_id = last_node_id = start_node_id;
                    node = last_node = trie_get_node(self, node_id);
                    break;
                } else if (node.base < 0) {
                    // Leaf: the rest of the key is a tail string, which may span later tokens
                    log_debug("Searching tail at index %d\n", i);

                    uint32_t data_index = -1*node.base;
                    trie_data_node_t data_node = self->data->a[data_index];
                    uint32_t current_tail_pos = data_node.tail;

                    unsigned char *current_tail = self->tail->a + current_tail_pos;

                    log_debug("token_length = %zu, j=%d\n", token_length, j);

                    // Bytes of this token not yet consumed by trie transitions
                    size_t ptr_len = token_length - offset;
                    log_debug("next node tail: %s vs %.*s\n", current_tail, (int)ptr_len, ptr + 1);

                    if (last_state == SEARCH_STATE_NO_MATCH || last_state == SEARCH_STATE_BEGIN) {
                        log_debug("phrase start at %d\n", i);
                        phrase_start = i;
                    }
                    if (strncmp((char *)current_tail, ptr + 1, ptr_len) == 0) {
                        // Rest of this token matches the tail's prefix; match the remainder against following tokens
                        log_debug("node tail matches first token\n");
                        int tail_search_result = trie_node_search_tail_tokens(self, node, str, tokens, ptr_len, i + 1);
                        log_debug("tail_search_result=%d\n", tail_search_result);
                        node_id = start_node_id;
                        node = trie_get_node(self, node_id);
                        check_continuation = false;

                        if (tail_search_result != -1) {
                            // Jump i to the last token consumed by the tail match
                            phrase_len = tail_search_result - phrase_start + 1;
                            last_match_index = i = tail_search_result;
                            last_state = SEARCH_STATE_MATCH;
                            data = data_node.data;
                        }
                        break;

                    } else {
                        node_id = start_node_id;
                        node = trie_get_node(self, node_id);
                        break;
                    }
                }
            }
        } else {
            // Whitespace takes no trie transition; skip it entirely when not inside a phrase
            check_continuation = false;
            if (state == SEARCH_STATE_BEGIN || state == SEARCH_STATE_NO_MATCH) {
                continue;
            }
        }


        // Back at the start node (or on a node with check <= 0): this token did not extend a phrase
        if (node.check <= 0 || node_id == start_node_id) {
            log_debug("state = SEARCH_STATE_NO_MATCH\n");
            state = SEARCH_STATE_NO_MATCH;
            // check
            if (last_match_index != -1) {
                // Emit the pending match and resume scanning after its last token
                log_debug("last_match not NULL and state==SEARCH_STATE_NO_MATCH, data=%d\n", data);
                if (*phrases == NULL) {
                    *phrases = phrase_array_new_size(1);
                }
                phrase_array_push(*phrases, (phrase_t){phrase_start, last_match_index - phrase_start + 1, data});
                i = last_match_index;
                last_match_index = -1;
                phrase_start = phrase_len = 0;
                node_id = last_node_id = start_node_id;
                node = last_node = trie_get_node(self, start_node_id);
                continue;
            } else if (last_state == SEARCH_STATE_PARTIAL_MATCH) {
                // we're in the middle of a phrase that did not fully
                // match the trie. Imagine a trie that contains "a b c" and "b"
                // and our string is "a b d". When we get to "d", we've matched
                // the prefix "a b" and then will fall off the trie. So we'll actually
                // need to go back to token "b" and start searching from the root

                // Note: i gets incremented with the next iteration of the for loop
                i = phrase_start;
                log_debug("last_state == SEARCH_STATE_PARTIAL_MATCH, i = %d\n", i);
                last_match_index = -1;
                phrase_start = phrase_len = 0;
                node_id = last_node_id = start_node_id;
                node = last_node = trie_get_node(self, start_node_id);
                continue;
            } else {
                phrase_start = phrase_len = 0;
                // this token was not a phrase
                log_debug("Plain token=%.*s\n", (int)token.len, str + token.offset);
            }
            node_id = last_node_id = start_node_id;
            node = last_node = trie_get_node(self, start_node_id);
        } else {

            state = SEARCH_STATE_PARTIAL_MATCH;
            if (!(node.base < 0) && (last_state == SEARCH_STATE_NO_MATCH || last_state == SEARCH_STATE_BEGIN)) {
                log_debug("phrase_start=%d, node.base = %d, last_state=%d\n", i, node.base, last_state);
                phrase_start = i;
            }

            // A '\0' transition means the tokens consumed so far form a complete key
            trie_node_t terminal_node = trie_get_transition(self, node, '\0');
            if (terminal_node.check == node_id) {
                log_debug("node match at %d\n", i);
                state = SEARCH_STATE_MATCH;
                int32_t data_index = -1*terminal_node.base;
                trie_data_node_t data_node = self->data->a[data_index];
                data = data_node.data;
                log_debug("data = %d\n", data);

                log_debug("phrase_start = %d\n", phrase_start);

                last_match_index = i;
                log_debug("last_match_index = %d\n", i);
            }

            // Last token: emit the pending match (if any) before the loop ends
            if (i == tokens->n - 1) {
                if (last_match_index == -1) {
                    log_debug("At last token\n");
                    break;
                } else {
                    if (*phrases == NULL) {
                        *phrases = phrase_array_new_size(1);
                    }
                    phrase_array_push(*phrases, (phrase_t){phrase_start, last_match_index - phrase_start + 1, data});
                    i = last_match_index;
                    last_match_index = -1;
                    phrase_start = phrase_len = 0;
                    node_id = last_node_id = start_node_id;
                    node = last_node = trie_get_node(self, start_node_id);
                    state = SEARCH_STATE_NO_MATCH;
                    continue;
                }
            }

            if (check_continuation) {

                // Check continuation
                // Try to extend the phrase to the next token via a ' ' transition
                uint32_t continuation_id = trie_get_transition_index(self, node, ' ');
                log_debug("transition_id: %u\n", continuation_id);
                trie_node_t continuation = trie_get_node(self, continuation_id);

                if (token.type == IDEOGRAPHIC_CHAR && continuation.check != node_id) {
                    // No ' ' link after an ideographic token: stay on this node so the next token
                    // continues directly from it (presumably because such text has no spaces between words)
                    log_debug("Ideographic character\n");
                    last_node_id = node_id;
                    last_node = node;
                } else if (continuation.check != node_id && last_match_index != -1) {
                    // Cannot continue but have a complete match: emit it and resume after it
                    log_debug("node->match no continuation\n");
                    if (*phrases == NULL) {
                        *phrases = phrase_array_new_size(1);
                    }
                    phrase_array_push(*phrases, (phrase_t){phrase_start, last_match_index - phrase_start + 1, data});
                    i = last_match_index;
                    last_match_index = -1;
                    phrase_start = phrase_len = 0;
                    node_id = last_node_id = start_node_id;
                    node = last_node = trie_get_node(self, start_node_id);
                    state = SEARCH_STATE_BEGIN;
                } else if (continuation.check != node_id) {
                    log_debug("No continuation for phrase with start=%d, yielding tokens\n", phrase_start);
                    state = SEARCH_STATE_NO_MATCH;
                    phrase_start = phrase_len = 0;
                    node_id = last_node_id = start_node_id;
                    node = last_node = trie_get_node(self, start_node_id);
                } else {
                    log_debug("Has continuation, node_id=%d\n", continuation_id);
                    last_node = node = continuation;
                    last_node_id = node_id = continuation_id;
                }
            }
        }

    }

    // Flush a match still pending when the tokens ran out
    if (last_match_index != -1) {
        if (*phrases == NULL) {
            *phrases = phrase_array_new_size(1);
        }
        log_debug("adding phrase, last_match_index=%d\n", last_match_index);
        phrase_array_push(*phrases, (phrase_t){phrase_start, last_match_index - phrase_start + 1, data});
   }

    return true;
}
463 
/*
 * Token-level search from the trie root, appending matches to *phrases.
 * Thin wrapper over trie_search_tokens_from_index with ROOT_NODE_ID.
 */
inline bool trie_search_tokens_with_phrases(trie_t *self, char *str, token_array *tokens, phrase_array **phrases) {
    bool ok = trie_search_tokens_from_index(self, str, tokens, ROOT_NODE_ID, phrases);
    return ok;
}
467 
trie_search_tokens(trie_t * self,char * str,token_array * tokens)468 inline phrase_array *trie_search_tokens(trie_t *self, char *str, token_array *tokens) {
469     phrase_array *phrases = NULL;
470     if (!trie_search_tokens_with_phrases(self, str, tokens, &phrases)) {
471         return NULL;
472     }
473     return phrases;
474 }
475 
trie_search_suffixes_from_index(trie_t * self,char * word,size_t len,uint32_t start_node_id)476 phrase_t trie_search_suffixes_from_index(trie_t *self, char *word, size_t len, uint32_t start_node_id) {
477     uint32_t last_node_id = start_node_id;
478     trie_node_t last_node = trie_get_node(self, last_node_id);
479     uint32_t node_id = last_node_id;
480     trie_node_t node = last_node;
481 
482     uint32_t value = 0, phrase_start = 0, phrase_len = 0;
483 
484     ssize_t char_len;
485 
486     int32_t unich = 0;
487 
488     ssize_t index = len;
489     const uint8_t *ptr = (const uint8_t *)word;
490     const uint8_t *char_ptr;
491 
492     bool in_tail = false;
493     unsigned char *current_tail = (unsigned char *)"";
494     size_t tail_remaining = 0;
495 
496     uint32_t tail_value = 0;
497 
498     while(index > 0) {
499         char_len = utf8proc_iterate_reversed(ptr, index, &unich);
500 
501         if (char_len <= 0) return NULL_PHRASE;
502         if (!(utf8proc_codepoint_valid(unich))) return NULL_PHRASE;
503 
504         index -= char_len;
505         char_ptr = ptr + index;
506 
507         if (in_tail && tail_remaining >= char_len && strncmp((char *)current_tail, (char *)char_ptr, char_len) == 0) {
508             tail_remaining -= char_len;
509             current_tail += char_len;
510             phrase_start = (uint32_t)index;
511 
512             log_debug("tail matched at char %.*s (len=%zd)\n", (int)char_len, char_ptr, char_len);
513             log_debug("tail_remaining = %zu\n", tail_remaining);
514 
515             if (tail_remaining == 0) {
516                 log_debug("tail match! tail_value=%u\n",tail_value);
517                 phrase_len = (uint32_t)(len - index);
518                 value = tail_value;
519                 index = 0;
520                 break;
521             }
522             continue;
523         } else if (in_tail) {
524             break;
525         }
526 
527         for (int i=0; i < char_len; i++, char_ptr++, last_node = node, last_node_id = node_id) {
528             log_debug("char=%c\n", (unsigned char)*char_ptr);
529 
530             node_id = trie_get_transition_index(self, node, *char_ptr);
531             node = trie_get_node(self, node_id);
532 
533             if (node.check != last_node_id) {
534                 log_debug("node.check = %d and last_node_id = %d\n", node.check, last_node_id);
535                 index = 0;
536                 break;
537             } else if (node.base < 0) {
538                 log_debug("Searching tail\n");
539 
540                 uint32_t data_index = -1*node.base;
541                 trie_data_node_t data_node = self->data->a[data_index];
542                 uint32_t current_tail_pos = data_node.tail;
543 
544                 tail_value = data_node.data;
545 
546                 current_tail = self->tail->a + current_tail_pos;
547 
548                 tail_remaining = strlen((char *)current_tail);
549                 log_debug("tail_remaining=%zu\n", tail_remaining);
550                 in_tail = true;
551 
552                 size_t remaining_char_len = char_len - i - 1;
553                 log_debug("remaining_char_len = %zu\n", remaining_char_len);
554 
555                 if (remaining_char_len > 0 && strncmp((char *)char_ptr + 1, (char *)current_tail, remaining_char_len) == 0) {
556                     log_debug("tail string comparison successful\n");
557                     tail_remaining -= remaining_char_len;
558                     current_tail += remaining_char_len;
559                 } else if (remaining_char_len > 0) {
560                     log_debug("tail comparison unsuccessful, %s vs %s\n", char_ptr, current_tail);
561                     index = 0;
562                     break;
563                 }
564 
565                 if (tail_remaining == 0) {
566                     phrase_start = (uint32_t)index;
567                     phrase_len = (uint32_t)(len - index);
568                     log_debug("phrase_start = %d, phrase_len=%d\n", phrase_start, phrase_len);
569                     value = tail_value;
570                     index = 0;
571                 }
572                 break;
573             } else if (i == char_len - 1) {
574                 trie_node_t terminal_node = trie_get_transition(self, node, '\0');
575                 if (terminal_node.check == node_id) {
576                     int32_t data_index = -1 * terminal_node.base;
577                     trie_data_node_t data_node = self->data->a[data_index];
578                     value = data_node.data;
579                     phrase_start = (uint32_t)index;
580                     phrase_len = (uint32_t)(len - index);
581                 }
582             }
583 
584         }
585 
586     }
587 
588     return (phrase_t) {phrase_start, phrase_len, value};
589 }
590 
trie_search_suffixes_from_index_get_suffix_char(trie_t * self,char * word,size_t len,uint32_t start_node_id)591 inline phrase_t trie_search_suffixes_from_index_get_suffix_char(trie_t *self, char *word, size_t len, uint32_t start_node_id) {
592     if (word == NULL || len == 0) return NULL_PHRASE;
593     trie_node_t node = trie_get_node(self, start_node_id);
594     unsigned char suffix_char = TRIE_SUFFIX_CHAR[0];
595     uint32_t node_id = trie_get_transition_index(self, node, suffix_char);
596     node = trie_get_node(self, node_id);
597 
598     if (node.check != start_node_id) {
599         log_debug("node.check != start_node_id\n");
600         return NULL_PHRASE;
601     }
602 
603     return trie_search_suffixes_from_index(self, word, len, node_id);
604 }
605 
trie_search_suffixes(trie_t * self,char * word,size_t len)606 inline phrase_t trie_search_suffixes(trie_t *self, char *word, size_t len) {
607     if (word == NULL || len == 0) return NULL_PHRASE;
608     return trie_search_suffixes_from_index_get_suffix_char(self, word, len, ROOT_NODE_ID);
609 }
610 
611 
trie_search_prefixes_from_index(trie_t * self,char * word,size_t len,uint32_t start_node_id)612 phrase_t trie_search_prefixes_from_index(trie_t *self, char *word, size_t len, uint32_t start_node_id) {
613     log_debug("Call to trie_search_prefixes_from_index\n");
614     uint32_t node_id = start_node_id, last_node_id = node_id;
615     trie_node_t node = trie_get_node(self, node_id), last_node = node;
616 
617     log_debug("last_node_id = %d\n", last_node_id);
618 
619     uint32_t value = 0, phrase_start = 0, phrase_len = 0;
620 
621     uint8_t *ptr = (uint8_t *)word;
622 
623     ssize_t char_len = 0;
624 
625     uint32_t idx = 0;
626 
627     size_t separator_char_len = 0;
628 
629     int32_t codepoint = 0;
630 
631     bool first_char = true;
632 
633     trie_data_node_t data_node;
634     trie_node_t terminal_node;
635 
636     bool phrase_at_hyphen = false;
637 
638     while (idx < len) {
639         char_len = utf8proc_iterate(ptr, len, &codepoint);
640         log_debug("char_len = %zu, char=%d\n", char_len, codepoint);
641         if (char_len <= 0) break;
642         if (!(utf8proc_codepoint_valid(codepoint))) break;
643 
644         bool is_hyphen = utf8_is_hyphen(codepoint);
645 
646         int cat = utf8proc_category(codepoint);
647         bool is_space = utf8_is_separator(cat);
648 
649         uint8_t *char_ptr = ptr;
650         size_t i = 0;
651 
652         bool skip_char = false;
653         bool break_out = false;
654 
655         for (i = 0; i < char_len; i++) {
656             node_id = trie_get_transition_index(self, last_node, *char_ptr);
657             node = trie_get_node(self, node_id);
658             log_debug("At idx=%u, i=%zu, char=%.*s\n", idx, i, (int)char_len, char_ptr);
659 
660             if (node.check != last_node_id) {
661                 log_debug("node.check = %d and last_node_id = %d\n", node.check, last_node_id);
662 
663                 if (is_hyphen || (is_space && *ptr != ' ')) {
664                     log_debug("Got hyphen or other separator, trying space instead\n");
665                     node_id = trie_get_transition_index(self, last_node, ' ');
666                     node = trie_get_node(self, node_id);
667                 }
668 
669                 if (is_hyphen && node.check != last_node_id) {
670                     log_debug("No space transition, phrase_len=%zu\n", phrase_len);
671                     if (phrase_len > 0 && phrase_len == idx) {
672                         log_debug("phrase_at_hyphen\n");
673                         phrase_at_hyphen = true;
674                     }
675 
676                     ptr += char_len;
677                     idx += char_len;
678                     separator_char_len = char_len;
679                     node_id = last_node_id;
680                     node = trie_get_node(self, node_id);
681                     skip_char = true;
682                     break;
683                 } else if (node.check != last_node_id) {
684                     break_out = true;
685                     log_debug("Breaking\n");
686                     break;
687                 }
688                 break;
689             }
690 
691             if (first_char) {
692                 phrase_start = idx;
693                 first_char = false;
694             }
695 
696             if (node.base < 0) {
697                 log_debug("Searching tail\n");
698 
699                 data_node = trie_get_data_node(self, node);
700                 uint32_t current_tail_pos = data_node.tail;
701 
702                 unsigned char *current_tail = self->tail->a + current_tail_pos;
703 
704                 log_debug("comparing tail: %s vs %s\n", current_tail, char_ptr + 1);
705                 size_t current_tail_len = strlen((char *)current_tail);
706 
707                 size_t match_len = i + 1;
708                 size_t offset = i + 1;
709                 size_t tail_pos = 0;
710                 log_debug("offset=%zu\n", offset);
711 
712                 if (char_len > 1) {
713                     log_debug("char_len = %zu\n", char_len);
714                     log_debug("Doing strncmp: (%zu) %s vs %s\n", char_len - offset, current_tail, char_ptr + 1);
715 
716                     if (strncmp((char *)ptr + offset, (char *)current_tail, char_len - offset) == 0) {
717                         match_len += char_len - offset;
718                         tail_pos = char_len - offset;
719                         log_debug("in char match_len = %zu\n", match_len);
720                     } else {
721                         return NULL_PHRASE;
722                     }
723                 }
724 
725                 size_t tail_match_len = utf8_common_prefix_len((char *)ptr + char_len, (char *)current_tail + tail_pos, current_tail_len - tail_pos);
726                 match_len += tail_match_len;
727                 log_debug("match_len=%zu\n", match_len);
728 
729                 if (tail_match_len == current_tail_len - tail_pos) {
730                     if (phrase_at_hyphen) {
731                         char_len = utf8proc_iterate(ptr + char_len, len, &codepoint);
732                         if (char_len > 0 && utf8proc_codepoint_valid(codepoint)) {
733                             int cat = utf8proc_category(codepoint);
734 
735                             if (codepoint != 0 && !utf8_is_hyphen(codepoint) && !utf8_is_separator(cat) && !utf8_is_punctuation(cat)) {
736                                 return (phrase_t){phrase_start, phrase_len, value};
737                             }
738                         }
739                     }
740                     if (first_char) phrase_start = idx;
741                     phrase_len = (uint32_t)(idx + match_len) - phrase_start;
742 
743                     log_debug("tail match! phrase_len=%u, len=%zu\n", phrase_len, len);
744                     value = data_node.data;
745                     return (phrase_t){phrase_start, phrase_len, value};
746                 } else {
747                     return NULL_PHRASE;
748                 }
749 
750             } else if (node.check == last_node_id) {
751                 terminal_node = trie_get_transition(self, node, '\0');
752                 log_debug("Trying link from %d to terminal node\n", last_node_id);
753 
754                 if (terminal_node.check == node_id) {
755                     log_debug("Transition to NUL byte matched\n");
756                     if (terminal_node.base < 0) {
757                         phrase_len = (uint32_t)(idx + char_len) - phrase_start;
758                         data_node = trie_get_data_node(self, terminal_node);
759                         value = data_node.data;
760                     }
761                     log_debug("Got match with len=%d\n", phrase_len);
762                 }
763             }
764 
765             last_node = node;
766             last_node_id = node_id;
767             log_debug("last_node_id = %d\n", last_node_id);
768             char_ptr++;
769         }
770 
771 
772         if (break_out) {
773             break;
774         } else if (skip_char) {
775             continue;
776         }
777 
778         log_debug("Incrementing index\n");
779 
780         idx += char_len;
781         ptr += char_len;
782     }
783 
784     log_debug("exited while loop\n");
785 
786     if (phrase_len == 0) return NULL_PHRASE;
787 
788     return (phrase_t) {phrase_start, phrase_len, value};
789 }
790 
trie_search_prefixes_from_index_get_prefix_char(trie_t * self,char * word,size_t len,uint32_t start_node_id)791 inline phrase_t trie_search_prefixes_from_index_get_prefix_char(trie_t *self, char *word, size_t len, uint32_t start_node_id) {
792     trie_node_t node = trie_get_node(self, start_node_id);
793     unsigned char prefix_char = TRIE_PREFIX_CHAR[0];
794     uint32_t node_id = trie_get_transition_index(self, node, prefix_char);
795     node = trie_get_node(self, node_id);
796 
797     if (node.check != start_node_id) {
798         return NULL_PHRASE;
799     }
800 
801     return trie_search_prefixes_from_index(self, word, len, node_id);
802 }
803 
trie_search_prefixes(trie_t * self,char * word,size_t len)804 inline phrase_t trie_search_prefixes(trie_t *self, char *word, size_t len) {
805     if (word == NULL || len == 0) return NULL_PHRASE;
806     return trie_search_prefixes_from_index_get_prefix_char(self, word, len, ROOT_NODE_ID);
807 }
808 
token_phrase_memberships(phrase_array * phrases,int64_array * phrase_memberships,size_t len)809 bool token_phrase_memberships(phrase_array *phrases, int64_array *phrase_memberships, size_t len) {
810     if (phrases == NULL || phrase_memberships == NULL) {
811         return false;
812     }
813 
814     int64_t i = 0;
815     for (int64_t j = 0; j < phrases->n; j++) {
816         phrase_t phrase = phrases->a[j];
817 
818         for (; i < phrase.start; i++) {
819             int64_array_push(phrase_memberships, NULL_PHRASE_MEMBERSHIP);
820             log_debug("token i=%" PRId64 ", null phrase membership\n", i);
821         }
822 
823         for (i = phrase.start; i < phrase.start + phrase.len; i++) {
824             log_debug("token i=%" PRId64 ", phrase membership=%" PRId64 "\n", i, j);
825             int64_array_push(phrase_memberships, j);
826         }
827     }
828 
829     for (; i < len; i++) {
830         log_debug("token i=%" PRId64 ", null phrase membership\n", i);
831         int64_array_push(phrase_memberships, NULL_PHRASE_MEMBERSHIP);
832     }
833 
834     return true;
835 }
836 
cstring_array_get_phrase(cstring_array * str,char_array * phrase_tokens,phrase_t phrase)837 inline char *cstring_array_get_phrase(cstring_array *str, char_array *phrase_tokens, phrase_t phrase) {
838     char_array_clear(phrase_tokens);
839 
840     size_t phrase_end = phrase.start + phrase.len;
841 
842     for (int k = phrase.start; k < phrase_end; k++) {
843         char *w = cstring_array_get_string(str, k);
844         char_array_append(phrase_tokens, w);
845         if (k < phrase_end - 1) {
846             char_array_append(phrase_tokens, " ");
847         }
848     }
849     char_array_terminate(phrase_tokens);
850 
851     return char_array_get_string(phrase_tokens);
852 }
853 
854