1 #include "trie_search.h"
2
/*
 * States used by the phrase-search loops in this file.
 *
 * The numeric values are the default enumerator values (0..3), spelled out
 * here because the state is printed with %d in debug logging.
 */
typedef enum {
    /* Initial state, and the state after a failed transition on a non-letter. */
    SEARCH_STATE_BEGIN = 0,
    /* A transition failed; the text search skips the remaining letters of the word. */
    SEARCH_STATE_NO_MATCH = 1,
    /* The input consumed so far follows valid trie transitions. */
    SEARCH_STATE_PARTIAL_MATCH = 2,
    /* A complete key (tail or NUL transition) has been matched. */
    SEARCH_STATE_MATCH = 3
} trie_search_state_t;
9
/*
 * Scan `text` for keys stored in the trie, walking the trie one byte at a
 * time and restarting from `start_node_id` whenever a transition fails.
 *
 * Each match is appended to *phrases as a phrase_t {phrase_start, phrase_len,
 * data}. phrase_start and phrase_len are derived from `index`, which advances
 * by the byte length of each decoded code point, so they are byte offsets
 * into `text`. *phrases is allocated on first use when it is NULL.
 *
 * Scanning stops once a code point with value 0 is decoded, so `text` is
 * expected to be NUL-terminated.
 *
 * Returns false if `text` is NULL, if utf8proc_iterate() reports an error or
 * a zero length, or if the decoded code point is invalid; otherwise true.
 * On a false return *phrases may already contain phrases pushed earlier in
 * the scan.
 */
bool trie_search_from_index(trie_t *self, char *text, uint32_t start_node_id, phrase_array **phrases) {
    if (text == NULL) return false;

    ssize_t len, remaining;
    int32_t unich = 0;
    unsigned char ch = '\0';

    const uint8_t *ptr = (const uint8_t *)text;
    // Position the scan is rewound to when a longer candidate fails after a shorter match
    const uint8_t *fail_ptr = ptr;

    uint32_t node_id = start_node_id;
    trie_node_t node = trie_get_node(self, node_id), last_node = node;
    uint32_t next_id;

    // True while a match has been recorded but not yet pushed to *phrases
    bool match = false;
    // Byte offset of the current code point within text
    uint32_t index = 0;
    uint32_t phrase_len = 0;
    uint32_t phrase_start = 0;
    // Payload of the most recent match.
    // NOTE(review): not initialized here; it is only assigned on the tail-match path and
    // when the NUL-transition node has base < 0 — confirm it cannot be pushed unset.
    uint32_t data;

    trie_search_state_t state = SEARCH_STATE_BEGIN, last_state = SEARCH_STATE_BEGIN;

    // Cleared whenever index has already been repositioned explicitly during this iteration
    bool advance_index = true;

    while(1) {
        // Decode the next code point; len is the value returned by utf8proc_iterate
        len = utf8proc_iterate(ptr, -1, &unich);
        remaining = len;
        if (len <= 0) return false;
        if (!(utf8proc_codepoint_valid(unich))) return false;

        int cat = utf8proc_category(unich);
        bool is_letter = utf8_is_letter(cat);

        // If we're in the middle of a word and the first letter was not a match, skip the word
        if (is_letter && state == SEARCH_STATE_NO_MATCH) {
            log_debug("skipping\n");
            ptr += len;
            index += len;
            last_state = state;
            continue;
        }

        // Match in the middle of a word
        if (is_letter && last_state == SEARCH_STATE_MATCH) {
            log_debug("last_state == SEARCH_STATE_MATCH && is_letter\n");
            // Only set match to false so we don't callback
            match = false;
        }

        // Feed the bytes of this code point through the trie one at a time.
        // The increment expression also rolls node/state forward for the next byte.
        for (int i=0; remaining > 0; remaining--, i++, ptr++, last_node=node, last_state=state, node_id=next_id) {
            ch = (unsigned char) *ptr;
            log_debug("char=%c\n", ch);

            next_id = trie_get_transition_index(self, node, *ptr);
            node = trie_get_node(self, next_id);

            // The transition is treated as valid only when the child's check equals the parent id
            if (node.check != node_id) {
                state = is_letter ? SEARCH_STATE_NO_MATCH : SEARCH_STATE_BEGIN;
                if (match) {
                    // A shorter phrase matched before we fell off the trie: emit it
                    log_debug("match is true and state==SEARCH_STATE_NO_MATCH\n");
                    if (*phrases == NULL) {
                        *phrases = phrase_array_new_size(1);
                    }
                    phrase_array_push(*phrases, (phrase_t){phrase_start, phrase_len, data});
                    index = phrase_start + phrase_len;
                    advance_index = false;
                    // Set the text back to the end of the last phrase
                    ptr = (const uint8_t *)text + index;
                    len = utf8proc_iterate(ptr, -1, &unich);
                    log_debug("ptr=%s\n", ptr);
                } else {
                    // No pending match: skip the rest of this code point's bytes
                    ptr += remaining;
                    log_debug("done with char, now at %s\n", ptr);
                }
                // Reset the walk to the start node. The break below skips the loop's
                // increment expression, so ptr stays exactly where it was just set.
                fail_ptr = ptr;
                node_id = start_node_id;
                last_node = node = trie_get_node(self, node_id);
                phrase_start = phrase_len = 0;
                last_state = state;
                match = false;
                break;
            } else {
                log_debug("node.check == node_id\n");
                state = SEARCH_STATE_PARTIAL_MATCH;
                // First matching byte after a non-match: a new candidate phrase starts here
                if (last_state == SEARCH_STATE_NO_MATCH || last_state == SEARCH_STATE_BEGIN) {
                    log_debug("phrase_start=%u\n", index);
                    phrase_start = index;
                    fail_ptr = ptr + remaining;
                }

                // A negative base is negated and used as an index into self->data; the rest
                // of the key is read from self->tail at that entry's tail offset
                if (node.base < 0) {
                    int32_t data_index = -1*node.base;
                    trie_data_node_t data_node = self->data->a[data_index];
                    unsigned char *current_tail = self->tail->a + data_node.tail;

                    size_t tail_len = strlen((char *)current_tail);
                    // Compare against the text following the current byte (or at it, if it is NUL)
                    char *query_tail = (char *)(*ptr ? ptr + 1 : ptr);
                    size_t query_tail_len = strlen((char *)query_tail);
                    log_debug("next node tail: %s\n", current_tail);
                    log_debug("query node tail: %s\n", query_tail);

                    if (tail_len <= query_tail_len && strncmp((char *)current_tail, query_tail, tail_len) == 0) {
                        state = SEARCH_STATE_MATCH;
                        log_debug("Tail matches\n");
                        last_state = state;
                        data = data_node.data;
                        log_debug("%u, %d, %zu\n", index, phrase_len, tail_len);
                        // Jump over the matched tail; the loop increment then steps past the current byte
                        ptr += tail_len;
                        index += tail_len;
                        advance_index = false;
                        phrase_len = index + 1 - phrase_start;
                        match = true;
                    } else if (match) {
                        // The longer candidate failed; emit the shorter match recorded earlier
                        log_debug("match is true and longer phrase tail did not match\n");
                        log_debug("phrase_start=%d, phrase_len=%d\n", phrase_start, phrase_len);
                        if (*phrases == NULL) {
                            *phrases = phrase_array_new_size(1);
                        }
                        phrase_array_push(*phrases, (phrase_t){phrase_start, phrase_len, data});
                        ptr = fail_ptr;
                        match = false;
                        index = phrase_start + phrase_len;
                        advance_index = false;
                    }

                }

                // Check whether a key ends exactly here (a NUL transition out of this node)
                if (ch != '\0') {
                    trie_node_t terminal_node = trie_get_transition(self, node, '\0');
                    if (terminal_node.check == next_id) {
                        log_debug("Transition to NUL byte matched\n");
                        state = SEARCH_STATE_MATCH;
                        match = true;
                        phrase_len = index + (uint32_t)len - phrase_start;
                        if (terminal_node.base < 0) {
                            int32_t data_index = -1*terminal_node.base;
                            trie_data_node_t data_node = self->data->a[data_index];
                            data = data_node.data;
                        }
                        log_debug("Got match with len=%d\n", phrase_len);
                        fail_ptr = ptr;
                    }
                }
            }

        }

        // A decoded NUL code point ends the scan; flush a match that ends at the end of the text
        if (unich == 0) {
            if (last_state == SEARCH_STATE_MATCH) {
                log_debug("Found match at the end\n");
                if (*phrases == NULL) {
                    *phrases = phrase_array_new_size(1);
                }
                phrase_array_push(*phrases, (phrase_t){phrase_start, phrase_len, data});
            }
            break;
        }

        if (advance_index) index += len;

        advance_index = true;
        log_debug("index now %u\n", index);
    } // while

    return true;
}
176
/*
 * Convenience wrapper: run trie_search_from_index() on `str` starting at
 * ROOT_NODE_ID, collecting matches into *phrases. Returns that call's result
 * unchanged.
 */
inline bool trie_search_with_phrases(trie_t *self, char *str, phrase_array **phrases) {
    bool search_ok = trie_search_from_index(self, str, ROOT_NODE_ID, phrases);
    return search_ok;
}
180
trie_search(trie_t * self,char * text)181 inline phrase_array *trie_search(trie_t *self, char *text) {
182 phrase_array *phrases = NULL;
183 if (!trie_search_with_phrases(self, text, &phrases)) {
184 return false;
185 }
186 return phrases;
187 }
188
/*
 * Match the remainder of a node's tail string against consecutive tokens.
 *
 * The tail is read from self->tail at the offset stored in the data node
 * indexed by -node.base, starting `tail_index` bytes in. Tokens are taken
 * from `tokens` beginning at `token_index`; their text is str + offset.
 *
 * Returns the index of the last token covered by the tail, or -1 when the
 * tail does not match or the tokens run out before the tail is consumed.
 * If the tail is already exhausted on entry, token_index - 1 is returned.
 */
int trie_node_search_tail_tokens(trie_t *self, trie_node_t node, char *str, token_array *tokens, size_t tail_index, int token_index) {
    int32_t data_slot = -1*node.base;
    trie_data_node_t tail_data = self->data->a[data_slot];
    uint32_t tail_offset = tail_data.tail;

    log_debug("tail_index = %zu\n", tail_index);

    unsigned char *tail_cursor = self->tail->a + tail_offset + tail_index;

    // Nothing left in the tail: the previous token already completed the match
    if (*tail_cursor == '\0') {
        log_debug("tail matches!\n");
        return token_index - 1;
    }

    log_debug("Searching tail: %s\n", tail_cursor);
    size_t token_count = tokens->n;

    int tok_idx = token_index;
    while (tok_idx < token_count) {
        token_t tok = tokens->a[tok_idx];

        // Tail consumed by the tokens before this one
        if (*tail_cursor == '\0') {
            log_debug("tail matches!\n");
            return tok_idx - 1;
        }

        if (*tail_cursor == ' ') {
            // A whitespace token lines up with the space in the tail; the tail
            // cursor is left on the space for the next token to step over
            if (tok.type == WHITESPACE) {
                tok_idx++;
                continue;
            }
            tail_cursor++;
            log_debug("Got space, advancing pointer, tail_ptr=%s\n", tail_cursor);
        }

        char *tok_text = str + tok.offset;
        size_t tok_len = tok.len;

        log_debug("Tail string compare: %s with %.*s\n", tail_cursor, (int)tok_len, tok_text);

        if (strncmp((char *)tail_cursor, tok_text, tok_len) != 0) {
            return -1;
        }

        tail_cursor += tok_len;

        // Last token and the tail ends with it: full match
        if (tok_idx == token_count - 1 && *tail_cursor == '\0') {
            return tok_idx;
        }

        tok_idx++;
    }

    return -1;
}
238
239
/*
 * Search a tokenized string for keys stored in the trie.
 *
 * Tokens are read from `tokens`; the text of token i is the token.len bytes
 * at str + token.offset. Non-whitespace tokens are fed through the trie byte
 * by byte, and between tokens a ' ' transition is tried to see whether the
 * key continues. Each match is appended to *phrases as a phrase_t whose start
 * and length are token indices/counts (phrase_start is assigned from the
 * token loop index), with the matched entry's data value. *phrases is
 * allocated on first use when it is NULL.
 *
 * Returns false if `str` or `tokens` is NULL or `tokens` is empty; otherwise
 * true, whether or not any phrase was found.
 */
bool trie_search_tokens_from_index(trie_t *self, char *str, token_array *tokens, uint32_t start_node_id, phrase_array **phrases) {
    if (str == NULL || tokens == NULL || tokens->n == 0) return false;

    uint32_t node_id = start_node_id, last_node_id = start_node_id;
    trie_node_t node = trie_get_node(self, node_id), last_node = node;

    // Payload of the most recent match; assigned wherever last_match_index is set
    uint32_t data;

    // last_match_index: token index of the last token of the longest match seen so far, -1 if none
    int phrase_len = 0, phrase_start = 0, last_match_index = -1;

    trie_search_state_t state = SEARCH_STATE_BEGIN, last_state = SEARCH_STATE_BEGIN;

    token_t token;
    size_t token_length;

    log_debug("num_tokens: %zu\n", tokens->n);
    // Note: the body rewinds or advances i in several places; the loop's own i++ still
    // runs after every `continue`
    for (int i = 0; i < tokens->n; i++, last_state = state) {
        token = tokens->a[i];
        token_length = token.len;

        char *ptr = str + token.offset;
        log_debug("On %d, token=%.*s\n", i, (int)token_length, ptr);

        bool check_continuation = true;

        if (token.type != WHITESPACE) {
            // Walk the token's bytes through the trie
            for (int j = 0; j < token_length; j++, ptr++, last_node = node, last_node_id = node_id) {
                log_debug("Getting transition index for %d, (%d, %d)\n", node_id, node.base, node.check);
                size_t offset = j + 1;
                if (j > 0 || last_node.base >= 0) {
                    node_id = trie_get_transition_index(self, node, *ptr);
                    node = trie_get_node(self, node_id);
                    log_debug("Doing %c, got node_id=%d\n", *ptr, node_id);
                } else {
                    // The node reached via the preceding ' ' transition already has a tail
                    // (base < 0), so no byte of this token is consumed by a transition
                    log_debug("Tail stored on space node, rolling back one character\n");
                    ptr--;
                    offset = j;
                    log_debug("ptr=%s\n", ptr);
                }

                if (node.check != last_node_id && last_node.base >= 0) {
                    log_debug("Fell off trie. last_node_id=%d and node.check=%d\n", last_node_id, node.check);
                    node_id = last_node_id = start_node_id;
                    node = last_node = trie_get_node(self, node_id);
                    break;
                } else if (node.base < 0) {
                    // The rest of the key is stored as a tail string; compare it with the
                    // remaining bytes of this token and then with the following tokens
                    log_debug("Searching tail at index %d\n", i);

                    uint32_t data_index = -1*node.base;
                    trie_data_node_t data_node = self->data->a[data_index];
                    uint32_t current_tail_pos = data_node.tail;

                    unsigned char *current_tail = self->tail->a + current_tail_pos;

                    log_debug("token_length = %zu, j=%d\n", token_length, j);

                    // Number of bytes of this token not yet consumed by transitions
                    size_t ptr_len = token_length - offset;
                    log_debug("next node tail: %s vs %.*s\n", current_tail, (int)ptr_len, ptr + 1);

                    if (last_state == SEARCH_STATE_NO_MATCH || last_state == SEARCH_STATE_BEGIN) {
                        log_debug("phrase start at %d\n", i);
                        phrase_start = i;
                    }
                    if (strncmp((char *)current_tail, ptr + 1, ptr_len) == 0) {
                        log_debug("node tail matches first token\n");
                        int tail_search_result = trie_node_search_tail_tokens(self, node, str, tokens, ptr_len, i + 1);
                        log_debug("tail_search_result=%d\n", tail_search_result);
                        node_id = start_node_id;
                        node = trie_get_node(self, node_id);
                        check_continuation = false;

                        if (tail_search_result != -1) {
                            phrase_len = tail_search_result - phrase_start + 1;
                            // Jump the token loop to the last token covered by the tail
                            last_match_index = i = tail_search_result;
                            last_state = SEARCH_STATE_MATCH;
                            data = data_node.data;
                        }
                        break;

                    } else {
                        node_id = start_node_id;
                        node = trie_get_node(self, node_id);
                        break;
                    }
                }
            }
        } else {
            // Whitespace tokens are not fed through the trie; outside a candidate phrase
            // they are skipped entirely
            check_continuation = false;
            if (state == SEARCH_STATE_BEGIN || state == SEARCH_STATE_NO_MATCH) {
                continue;
            }
        }


        // node was reset to the start node above (or has check <= 0): this token did not
        // extend a trie path
        if (node.check <= 0 || node_id == start_node_id) {
            log_debug("state = SEARCH_STATE_NO_MATCH\n");
            state = SEARCH_STATE_NO_MATCH;
            // check
            if (last_match_index != -1) {
                // Emit the longest match recorded so far and resume after its last token
                log_debug("last_match not NULL and state==SEARCH_STATE_NO_MATCH, data=%d\n", data);
                if (*phrases == NULL) {
                    *phrases = phrase_array_new_size(1);
                }
                phrase_array_push(*phrases, (phrase_t){phrase_start, last_match_index - phrase_start + 1, data});
                i = last_match_index;
                last_match_index = -1;
                phrase_start = phrase_len = 0;
                node_id = last_node_id = start_node_id;
                node = last_node = trie_get_node(self, start_node_id);
                continue;
            } else if (last_state == SEARCH_STATE_PARTIAL_MATCH) {
                // we're in the middle of a phrase that did not fully
                // match the trie. Imagine a trie that contains "a b c" and "b"
                // and our string is "a b d". When we get to "d", we've matched
                // the prefix "a b" and then will fall off the trie. So we'll actually
                // need to go back to token "b" and start searching from the root

                // Note: i gets incremented with the next iteration of the for loop
                i = phrase_start;
                log_debug("last_state == SEARCH_STATE_PARTIAL_MATCH, i = %d\n", i);
                last_match_index = -1;
                phrase_start = phrase_len = 0;
                node_id = last_node_id = start_node_id;
                node = last_node = trie_get_node(self, start_node_id);
                continue;
            } else {
                phrase_start = phrase_len = 0;
                // this token was not a phrase
                log_debug("Plain token=%.*s\n", (int)token.len, str + token.offset);
            }
            node_id = last_node_id = start_node_id;
            node = last_node = trie_get_node(self, start_node_id);
        } else {

            state = SEARCH_STATE_PARTIAL_MATCH;
            if (!(node.base < 0) && (last_state == SEARCH_STATE_NO_MATCH || last_state == SEARCH_STATE_BEGIN)) {
                log_debug("phrase_start=%d, node.base = %d, last_state=%d\n", i, node.base, last_state);
                phrase_start = i;
            }

            // A NUL transition out of the current node means a key ends at this token
            trie_node_t terminal_node = trie_get_transition(self, node, '\0');
            if (terminal_node.check == node_id) {
                log_debug("node match at %d\n", i);
                state = SEARCH_STATE_MATCH;
                int32_t data_index = -1*terminal_node.base;
                trie_data_node_t data_node = self->data->a[data_index];
                data = data_node.data;
                log_debug("data = %d\n", data);

                log_debug("phrase_start = %d\n", phrase_start);

                last_match_index = i;
                log_debug("last_match_index = %d\n", i);
            }

            if (i == tokens->n - 1) {
                if (last_match_index == -1) {
                    log_debug("At last token\n");
                    break;
                } else {
                    // Last token reached with a recorded match: emit it, then let the loop
                    // resume after the match's last token
                    if (*phrases == NULL) {
                        *phrases = phrase_array_new_size(1);
                    }
                    phrase_array_push(*phrases, (phrase_t){phrase_start, last_match_index - phrase_start + 1, data});
                    i = last_match_index;
                    last_match_index = -1;
                    phrase_start = phrase_len = 0;
                    node_id = last_node_id = start_node_id;
                    node = last_node = trie_get_node(self, start_node_id);
                    state = SEARCH_STATE_NO_MATCH;
                    continue;
                }
            }

            if (check_continuation) {

                // Check continuation
                uint32_t continuation_id = trie_get_transition_index(self, node, ' ');
                log_debug("transition_id: %u\n", continuation_id);
                trie_node_t continuation = trie_get_node(self, continuation_id);

                if (token.type == IDEOGRAPHIC_CHAR && continuation.check != node_id) {
                    // No ' ' transition after an ideographic token: stay on the current
                    // node so the next token continues from it directly
                    log_debug("Ideographic character\n");
                    last_node_id = node_id;
                    last_node = node;
                } else if (continuation.check != node_id && last_match_index != -1) {
                    // The key cannot continue past this token: emit the recorded match
                    log_debug("node->match no continuation\n");
                    if (*phrases == NULL) {
                        *phrases = phrase_array_new_size(1);
                    }
                    phrase_array_push(*phrases, (phrase_t){phrase_start, last_match_index - phrase_start + 1, data});
                    i = last_match_index;
                    last_match_index = -1;
                    phrase_start = phrase_len = 0;
                    node_id = last_node_id = start_node_id;
                    node = last_node = trie_get_node(self, start_node_id);
                    state = SEARCH_STATE_BEGIN;
                } else if (continuation.check != node_id) {
                    log_debug("No continuation for phrase with start=%d, yielding tokens\n", phrase_start);
                    state = SEARCH_STATE_NO_MATCH;
                    phrase_start = phrase_len = 0;
                    node_id = last_node_id = start_node_id;
                    node = last_node = trie_get_node(self, start_node_id);
                } else {
                    // Step onto the ' ' node so the next token continues the same key
                    log_debug("Has continuation, node_id=%d\n", continuation_id);
                    last_node = node = continuation;
                    last_node_id = node_id = continuation_id;
                }
            }
        }

    }

    // Flush a match still pending when the token loop ends
    if (last_match_index != -1) {
        if (*phrases == NULL) {
            *phrases = phrase_array_new_size(1);
        }
        log_debug("adding phrase, last_match_index=%d\n", last_match_index);
        phrase_array_push(*phrases, (phrase_t){phrase_start, last_match_index - phrase_start + 1, data});
    }

    return true;
}
463
/*
 * Convenience wrapper: run trie_search_tokens_from_index() starting at
 * ROOT_NODE_ID, collecting matches into *phrases. Returns that call's result
 * unchanged.
 */
inline bool trie_search_tokens_with_phrases(trie_t *self, char *str, token_array *tokens, phrase_array **phrases) {
    bool search_ok = trie_search_tokens_from_index(self, str, tokens, ROOT_NODE_ID, phrases);
    return search_ok;
}
467
trie_search_tokens(trie_t * self,char * str,token_array * tokens)468 inline phrase_array *trie_search_tokens(trie_t *self, char *str, token_array *tokens) {
469 phrase_array *phrases = NULL;
470 if (!trie_search_tokens_with_phrases(self, str, tokens, &phrases)) {
471 return NULL;
472 }
473 return phrases;
474 }
475
/*
 * Look up the end of `word` (the first `len` bytes) in the trie, reading the
 * word backwards one code point at a time from `start_node_id`.
 *
 * Code points are visited from the end of the word towards the start, but the
 * bytes within each code point are fed to the trie in their original order.
 * Every time a key is completed, phrase_start/phrase_len/value are updated,
 * so the result describes the last (longest) key matched before the walk
 * stopped.
 *
 * Returns phrase_t {phrase_start, phrase_len, value}, where phrase_start is a
 * byte offset into `word` and phrase_len = len - phrase_start at the time of
 * the match. All three are 0 when nothing matched. Returns NULL_PHRASE when
 * utf8proc_iterate_reversed() yields a non-positive length or the decoded
 * code point is invalid.
 *
 * `word` is not checked for NULL here.
 */
phrase_t trie_search_suffixes_from_index(trie_t *self, char *word, size_t len, uint32_t start_node_id) {
    uint32_t last_node_id = start_node_id;
    trie_node_t last_node = trie_get_node(self, last_node_id);
    uint32_t node_id = last_node_id;
    trie_node_t node = last_node;

    uint32_t value = 0, phrase_start = 0, phrase_len = 0;

    ssize_t char_len;

    int32_t unich = 0;

    // Byte offset of the start of the part of the word already consumed; counts down from len
    ssize_t index = len;
    const uint8_t *ptr = (const uint8_t *)word;
    const uint8_t *char_ptr;

    // Once a node with a tail is reached, the remaining comparison is done against the tail string
    bool in_tail = false;
    unsigned char *current_tail = (unsigned char *)"";
    size_t tail_remaining = 0;

    uint32_t tail_value = 0;

    while(index > 0) {
        // presumably decodes the code point that ends at byte offset `index` — TODO confirm
        // against utf8proc_iterate_reversed's definition
        char_len = utf8proc_iterate_reversed(ptr, index, &unich);

        if (char_len <= 0) return NULL_PHRASE;
        if (!(utf8proc_codepoint_valid(unich))) return NULL_PHRASE;

        index -= char_len;
        char_ptr = ptr + index;

        if (in_tail && tail_remaining >= char_len && strncmp((char *)current_tail, (char *)char_ptr, char_len) == 0) {
            // This code point matches the next char_len bytes of the tail
            tail_remaining -= char_len;
            current_tail += char_len;
            // NOTE(review): phrase_start is moved here before the tail is fully matched, while
            // phrase_len is only updated on completion — confirm the values returned when the
            // tail later fails to match are intended.
            phrase_start = (uint32_t)index;

            log_debug("tail matched at char %.*s (len=%zd)\n", (int)char_len, char_ptr, char_len);
            log_debug("tail_remaining = %zu\n", tail_remaining);

            if (tail_remaining == 0) {
                log_debug("tail match! tail_value=%u\n",tail_value);
                phrase_len = (uint32_t)(len - index);
                value = tail_value;
                index = 0;
                break;
            }
            continue;
        } else if (in_tail) {
            // Tail mismatch (or tail shorter than this code point): stop searching
            break;
        }

        // Not in a tail yet: feed this code point's bytes through the trie
        for (int i=0; i < char_len; i++, char_ptr++, last_node = node, last_node_id = node_id) {
            log_debug("char=%c\n", (unsigned char)*char_ptr);

            node_id = trie_get_transition_index(self, node, *char_ptr);
            node = trie_get_node(self, node_id);

            if (node.check != last_node_id) {
                // No such transition: setting index to 0 also terminates the outer while loop
                log_debug("node.check = %d and last_node_id = %d\n", node.check, last_node_id);
                index = 0;
                break;
            } else if (node.base < 0) {
                log_debug("Searching tail\n");

                uint32_t data_index = -1*node.base;
                trie_data_node_t data_node = self->data->a[data_index];
                uint32_t current_tail_pos = data_node.tail;

                tail_value = data_node.data;

                current_tail = self->tail->a + current_tail_pos;

                tail_remaining = strlen((char *)current_tail);
                log_debug("tail_remaining=%zu\n", tail_remaining);
                in_tail = true;

                // Bytes of the current code point not yet consumed by transitions
                size_t remaining_char_len = char_len - i - 1;
                log_debug("remaining_char_len = %zu\n", remaining_char_len);

                if (remaining_char_len > 0 && strncmp((char *)char_ptr + 1, (char *)current_tail, remaining_char_len) == 0) {
                    log_debug("tail string comparison successful\n");
                    tail_remaining -= remaining_char_len;
                    current_tail += remaining_char_len;
                } else if (remaining_char_len > 0) {
                    log_debug("tail comparison unsuccessful, %s vs %s\n", char_ptr, current_tail);
                    index = 0;
                    break;
                }

                if (tail_remaining == 0) {
                    phrase_start = (uint32_t)index;
                    phrase_len = (uint32_t)(len - index);
                    log_debug("phrase_start = %d, phrase_len=%d\n", phrase_start, phrase_len);
                    value = tail_value;
                    index = 0;
                }
                break;
            } else if (i == char_len - 1) {
                // Last byte of the code point: record a match if a key ends at this node,
                // then keep going in case a longer key also matches
                trie_node_t terminal_node = trie_get_transition(self, node, '\0');
                if (terminal_node.check == node_id) {
                    int32_t data_index = -1 * terminal_node.base;
                    trie_data_node_t data_node = self->data->a[data_index];
                    value = data_node.data;
                    phrase_start = (uint32_t)index;
                    phrase_len = (uint32_t)(len - index);
                }
            }

        }

    }

    return (phrase_t) {phrase_start, phrase_len, value};
}
590
trie_search_suffixes_from_index_get_suffix_char(trie_t * self,char * word,size_t len,uint32_t start_node_id)591 inline phrase_t trie_search_suffixes_from_index_get_suffix_char(trie_t *self, char *word, size_t len, uint32_t start_node_id) {
592 if (word == NULL || len == 0) return NULL_PHRASE;
593 trie_node_t node = trie_get_node(self, start_node_id);
594 unsigned char suffix_char = TRIE_SUFFIX_CHAR[0];
595 uint32_t node_id = trie_get_transition_index(self, node, suffix_char);
596 node = trie_get_node(self, node_id);
597
598 if (node.check != start_node_id) {
599 log_debug("node.check != start_node_id\n");
600 return NULL_PHRASE;
601 }
602
603 return trie_search_suffixes_from_index(self, word, len, node_id);
604 }
605
trie_search_suffixes(trie_t * self,char * word,size_t len)606 inline phrase_t trie_search_suffixes(trie_t *self, char *word, size_t len) {
607 if (word == NULL || len == 0) return NULL_PHRASE;
608 return trie_search_suffixes_from_index_get_suffix_char(self, word, len, ROOT_NODE_ID);
609 }
610
611
/*
 * Look up the beginning of `word` in the trie, starting at `start_node_id`,
 * and return the longest key found as phrase_t {phrase_start, phrase_len,
 * value} with byte offsets into `word`.
 *
 * The word is decoded one code point at a time and its bytes are fed through
 * the trie. When a byte has no transition and the code point is a hyphen or a
 * non-' ' separator, a ' ' transition is tried instead; a hyphen with no ' '
 * transition either is skipped. Reaching a node with a tail (base < 0) ends
 * the search immediately: the rest of the key is compared against the text
 * and the function returns from there.
 *
 * Returns NULL_PHRASE when no key matched (phrase_len == 0) or when a tail
 * comparison fails.
 *
 * `word` is not checked for NULL here. The tail comparison reads
 * ptr + char_len without reference to `len`, so `word` is presumably expected
 * to be NUL-terminated — TODO confirm against utf8_common_prefix_len.
 */
phrase_t trie_search_prefixes_from_index(trie_t *self, char *word, size_t len, uint32_t start_node_id) {
    log_debug("Call to trie_search_prefixes_from_index\n");
    uint32_t node_id = start_node_id, last_node_id = node_id;
    trie_node_t node = trie_get_node(self, node_id), last_node = node;

    log_debug("last_node_id = %d\n", last_node_id);

    uint32_t value = 0, phrase_start = 0, phrase_len = 0;

    uint8_t *ptr = (uint8_t *)word;

    ssize_t char_len = 0;

    // Byte offset of ptr within word
    uint32_t idx = 0;

    // Written when a hyphen is skipped; not read anywhere in this function
    size_t separator_char_len = 0;

    int32_t codepoint = 0;

    // True until the first byte that follows a valid transition; fixes phrase_start
    bool first_char = true;

    trie_data_node_t data_node;
    trie_node_t terminal_node;

    // Set when a skipped hyphen immediately follows the end of the current best match
    bool phrase_at_hyphen = false;

    while (idx < len) {
        // NOTE(review): the byte limit passed here is the full `len`, not the bytes remaining
        // after ptr has advanced by idx — confirm this cannot read past word + len.
        char_len = utf8proc_iterate(ptr, len, &codepoint);
        log_debug("char_len = %zu, char=%d\n", char_len, codepoint);
        if (char_len <= 0) break;
        if (!(utf8proc_codepoint_valid(codepoint))) break;

        bool is_hyphen = utf8_is_hyphen(codepoint);

        int cat = utf8proc_category(codepoint);
        bool is_space = utf8_is_separator(cat);

        uint8_t *char_ptr = ptr;
        size_t i = 0;

        // skip_char: the code point was already consumed inside the byte loop
        // break_out: the search fell off the trie and must stop
        bool skip_char = false;
        bool break_out = false;

        for (i = 0; i < char_len; i++) {
            node_id = trie_get_transition_index(self, last_node, *char_ptr);
            node = trie_get_node(self, node_id);
            log_debug("At idx=%u, i=%zu, char=%.*s\n", idx, i, (int)char_len, char_ptr);

            if (node.check != last_node_id) {
                log_debug("node.check = %d and last_node_id = %d\n", node.check, last_node_id);

                if (is_hyphen || (is_space && *ptr != ' ')) {
                    log_debug("Got hyphen or other separator, trying space instead\n");
                    node_id = trie_get_transition_index(self, last_node, ' ');
                    node = trie_get_node(self, node_id);
                }

                if (is_hyphen && node.check != last_node_id) {
                    // Hyphen with no ' ' transition either: skip over it and keep walking
                    // from the same node
                    log_debug("No space transition, phrase_len=%zu\n", phrase_len);
                    if (phrase_len > 0 && phrase_len == idx) {
                        log_debug("phrase_at_hyphen\n");
                        phrase_at_hyphen = true;
                    }

                    ptr += char_len;
                    idx += char_len;
                    separator_char_len = char_len;
                    node_id = last_node_id;
                    node = trie_get_node(self, node_id);
                    skip_char = true;
                    break;
                } else if (node.check != last_node_id) {
                    break_out = true;
                    log_debug("Breaking\n");
                    break;
                }
                // NOTE(review): reached when the ' ' fallback transition succeeded. The loop is
                // left without copying node/node_id into last_node/last_node_id, so the next
                // code point is looked up from the node before the space — confirm intended.
                break;
            }

            if (first_char) {
                phrase_start = idx;
                first_char = false;
            }

            if (node.base < 0) {
                // The rest of the key is stored as a tail string: compare and return
                log_debug("Searching tail\n");

                data_node = trie_get_data_node(self, node);
                uint32_t current_tail_pos = data_node.tail;

                unsigned char *current_tail = self->tail->a + current_tail_pos;

                log_debug("comparing tail: %s vs %s\n", current_tail, char_ptr + 1);
                size_t current_tail_len = strlen((char *)current_tail);

                // match_len counts bytes matched starting at ptr: the bytes of this code point
                // consumed by transitions so far, plus whatever the tail comparison adds
                size_t match_len = i + 1;
                size_t offset = i + 1;
                size_t tail_pos = 0;
                log_debug("offset=%zu\n", offset);

                if (char_len > 1) {
                    // Multi-byte code point: its remaining bytes must match the start of the tail
                    log_debug("char_len = %zu\n", char_len);
                    log_debug("Doing strncmp: (%zu) %s vs %s\n", char_len - offset, current_tail, char_ptr + 1);

                    if (strncmp((char *)ptr + offset, (char *)current_tail, char_len - offset) == 0) {
                        match_len += char_len - offset;
                        tail_pos = char_len - offset;
                        log_debug("in char match_len = %zu\n", match_len);
                    } else {
                        return NULL_PHRASE;
                    }
                }

                size_t tail_match_len = utf8_common_prefix_len((char *)ptr + char_len, (char *)current_tail + tail_pos, current_tail_len - tail_pos);
                match_len += tail_match_len;
                log_debug("match_len=%zu\n", match_len);

                if (tail_match_len == current_tail_len - tail_pos) {
                    if (phrase_at_hyphen) {
                        // A shorter match ended right before a skipped hyphen. If the text after
                        // this code point is not a hyphen, separator, punctuation or NUL, return
                        // the shorter match recorded so far instead of this one.
                        char_len = utf8proc_iterate(ptr + char_len, len, &codepoint);
                        if (char_len > 0 && utf8proc_codepoint_valid(codepoint)) {
                            int cat = utf8proc_category(codepoint);

                            if (codepoint != 0 && !utf8_is_hyphen(codepoint) && !utf8_is_separator(cat) && !utf8_is_punctuation(cat)) {
                                return (phrase_t){phrase_start, phrase_len, value};
                            }
                        }
                    }
                    if (first_char) phrase_start = idx;
                    phrase_len = (uint32_t)(idx + match_len) - phrase_start;

                    log_debug("tail match! phrase_len=%u, len=%zu\n", phrase_len, len);
                    value = data_node.data;
                    return (phrase_t){phrase_start, phrase_len, value};
                } else {
                    return NULL_PHRASE;
                }

            } else if (node.check == last_node_id) {
                // Valid transition without a tail: record a match if a key ends here
                terminal_node = trie_get_transition(self, node, '\0');
                log_debug("Trying link from %d to terminal node\n", last_node_id);

                if (terminal_node.check == node_id) {
                    log_debug("Transition to NUL byte matched\n");
                    if (terminal_node.base < 0) {
                        phrase_len = (uint32_t)(idx + char_len) - phrase_start;
                        data_node = trie_get_data_node(self, terminal_node);
                        value = data_node.data;
                    }
                    log_debug("Got match with len=%d\n", phrase_len);
                }
            }

            last_node = node;
            last_node_id = node_id;
            log_debug("last_node_id = %d\n", last_node_id);
            char_ptr++;
        }


        if (break_out) {
            break;
        } else if (skip_char) {
            // ptr and idx were already advanced past the skipped hyphen
            continue;
        }

        log_debug("Incrementing index\n");

        idx += char_len;
        ptr += char_len;
    }

    log_debug("exited while loop\n");

    if (phrase_len == 0) return NULL_PHRASE;

    return (phrase_t) {phrase_start, phrase_len, value};
}
790
trie_search_prefixes_from_index_get_prefix_char(trie_t * self,char * word,size_t len,uint32_t start_node_id)791 inline phrase_t trie_search_prefixes_from_index_get_prefix_char(trie_t *self, char *word, size_t len, uint32_t start_node_id) {
792 trie_node_t node = trie_get_node(self, start_node_id);
793 unsigned char prefix_char = TRIE_PREFIX_CHAR[0];
794 uint32_t node_id = trie_get_transition_index(self, node, prefix_char);
795 node = trie_get_node(self, node_id);
796
797 if (node.check != start_node_id) {
798 return NULL_PHRASE;
799 }
800
801 return trie_search_prefixes_from_index(self, word, len, node_id);
802 }
803
trie_search_prefixes(trie_t * self,char * word,size_t len)804 inline phrase_t trie_search_prefixes(trie_t *self, char *word, size_t len) {
805 if (word == NULL || len == 0) return NULL_PHRASE;
806 return trie_search_prefixes_from_index_get_prefix_char(self, word, len, ROOT_NODE_ID);
807 }
808
/*
 * Append one entry per token position to `phrase_memberships`: the index of
 * the phrase covering that position, or NULL_PHRASE_MEMBERSHIP when none does.
 *
 * Phrases are processed in array order. For each phrase, positions between
 * the current position and phrase.start get NULL_PHRASE_MEMBERSHIP, then the
 * position is set to phrase.start and every position up to
 * phrase.start + phrase.len gets the phrase's index. Positions after the
 * last phrase, up to `len`, get NULL_PHRASE_MEMBERSHIP.
 *
 * Returns false if either array pointer is NULL, true otherwise.
 */
bool token_phrase_memberships(phrase_array *phrases, int64_array *phrase_memberships, size_t len) {
    if (phrase_memberships == NULL || phrases == NULL) {
        return false;
    }

    int64_t token_pos = 0;
    int64_t phrase_idx = 0;

    while (phrase_idx < phrases->n) {
        phrase_t current = phrases->a[phrase_idx];

        // Positions before the phrase belong to no phrase
        while (token_pos < current.start) {
            int64_array_push(phrase_memberships, NULL_PHRASE_MEMBERSHIP);
            log_debug("token i=%" PRId64 ", null phrase membership\n", token_pos);
            token_pos++;
        }

        // Positions inside the phrase are tagged with the phrase's index
        token_pos = current.start;
        while (token_pos < current.start + current.len) {
            log_debug("token i=%" PRId64 ", phrase membership=%" PRId64 "\n", token_pos, phrase_idx);
            int64_array_push(phrase_memberships, phrase_idx);
            token_pos++;
        }

        phrase_idx++;
    }

    // Trailing positions after the last phrase
    while (token_pos < len) {
        log_debug("token i=%" PRId64 ", null phrase membership\n", token_pos);
        int64_array_push(phrase_memberships, NULL_PHRASE_MEMBERSHIP);
        token_pos++;
    }

    return true;
}
836
/*
 * Build the text of a phrase by joining the strings at positions
 * [phrase.start, phrase.start + phrase.len) of `str` with single spaces.
 *
 * `phrase_tokens` is cleared first and used as the output buffer; the
 * returned pointer is whatever char_array_get_string() returns for it.
 */
inline char *cstring_array_get_phrase(cstring_array *str, char_array *phrase_tokens, phrase_t phrase) {
    char_array_clear(phrase_tokens);

    size_t phrase_end = phrase.start + phrase.len;

    int pos = phrase.start;
    while (pos < phrase_end) {
        char *piece = cstring_array_get_string(str, pos);
        char_array_append(phrase_tokens, piece);

        // Separator after every string except the last one
        bool has_next = pos < phrase_end - 1;
        if (has_next) {
            char_array_append(phrase_tokens, " ");
        }

        pos++;
    }

    char_array_terminate(phrase_tokens);

    return char_array_get_string(phrase_tokens);
}
853
854