1 /**
2
3 C-Template -- Boilerplate c project with cmake support, CuTest unit testing, and more.
4
5 @file aho-corasick.c
6
7 @brief C implementation of the Aho-Corasick algorithm for searching text
8 for multiple strings simultaneously in a single pass without backtracking.
9
10 <https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm>
11
12
13 @author Fletcher T. Penney
14 @bug
15
16 **/
17
18 /*
19
20 Copyright © 2015-2017 Fletcher T. Penney.
21
22
23 The `c-template` project is released under the MIT License.
24
25 GLibFacade.c and GLibFacade.h are from the MultiMarkdown v4 project:
26
27 https://github.com/fletcher/MultiMarkdown-4/
28
29 MMD 4 is released under both the MIT License and GPL.
30
31
32 CuTest is released under the zlib/libpng license. See CuTest.c for the text
33 of the license.
34
35
36 ## The MIT License ##
37
38 Permission is hereby granted, free of charge, to any person obtaining a copy
39 of this software and associated documentation files (the "Software"), to deal
40 in the Software without restriction, including without limitation the rights
41 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
42 copies of the Software, and to permit persons to whom the Software is
43 furnished to do so, subject to the following conditions:
44
45 The above copyright notice and this permission notice shall be included in
46 all copies or substantial portions of the Software.
47
48 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
49 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
50 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
51 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
52 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
53 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
54 THE SOFTWARE.
55
56 */
57
#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>

#include "aho-corasick.h"
63
64 #define kTrieStartingSize 256
65
66 void trie_to_graphviz(trie * a);
67
68
trie_new(size_t startingSize)69 trie * trie_new(size_t startingSize) {
70 trie * a = malloc(sizeof(trie));
71
72 if (a) {
73 if (startingSize <= 1) {
74 startingSize = kTrieStartingSize;
75 }
76
77 a->node = malloc(sizeof(trie_node) * startingSize);
78
79 if (!a->node) {
80 free(a);
81 return NULL;
82 }
83
84 // Clear memory
85 memset(a->node, 0, sizeof(trie_node) * startingSize);
86
87 // All tries have a root node
88 a->size = 1;
89 a->capacity = startingSize;
90 }
91
92 return a;
93 }
94
95
/// Release a trie and its node array.  Passing NULL is a no-op.
void trie_free(trie * a) {
	if (!a) {
		return;
	}

	free(a->node);
	free(a);
}
100
101
/// Insert the remainder of `key` below state `s`, creating nodes as needed.
///
/// @param a           trie to modify (its node array may be reallocated)
/// @param s           state to start from
/// @param key         NUL-terminated bytes still to be inserted
/// @param match_type  value stored on the node that ends the key
/// @param depth       number of key bytes already consumed before `s`
///
/// @return true on success; false if the node array could not be grown.
///         On failure the trie is still valid: nodes created before the
///         failure stay in place but no match is recorded for the key.
bool trie_node_insert(trie * a, size_t s, const unsigned char * key, unsigned short match_type, unsigned short depth) {
	for (; key[0] != '\0'; ++key, ++depth) {
		size_t next = a->node[s].child[key[0]];

		if (next == 0) {
			// No edge for this byte yet -- append a new node

			// Ensure capacity
			if (a->size == a->capacity) {
				// Refuse to grow if the new byte count would overflow size_t
				if (a->capacity > ((size_t) -1) / 2 / sizeof(trie_node)) {
					return false;
				}

				size_t new_capacity = a->capacity * 2;

				// Keep the old array until realloc() is known to have succeeded
				trie_node * grown = realloc(a->node, new_capacity * sizeof(trie_node));

				if (!grown) {
					return false;
				}

				a->node = grown;
				a->capacity = new_capacity;
			}

			next = a->size;

			// Initialize new node to 0 and record its character
			memset(&a->node[next], 0, sizeof(trie_node));
			a->node[next].c = key[0];

			// Current node points to the new node
			a->node[s].child[key[0]] = next;

			// Increment size
			a->size++;
		}

		// Advance forward
		s = next;
	}

	// We've hit end of key
	a->node[s].match_type = match_type;
	a->node[s].len = depth;

	return true;
}
148
149
/// Add `key` to the trie, tagging its final node with `match_type`.
///
/// Returns false without touching the trie when `a` or `key` is NULL or
/// when `key` is the empty string; otherwise returns whatever
/// trie_node_insert() reports when started from the root (state 0).
bool trie_insert(trie * a, const char * key, unsigned short match_type) {
	if (!a || !key || key[0] == '\0') {
		return false;
	}

	const unsigned char * bytes = (const unsigned char *)key;

	return trie_node_insert(a, 0, bytes, match_type, 0);
}
157
158
159 #ifdef TEST
/// Insert "foo" into an empty trie and verify the resulting chain of
/// states 0 -> 1 -> 2 -> 3, one node per character.
void Test_trie_insert(CuTest * tc) {
	trie * t = trie_new(0);

	// A requested size of 0 selects the default capacity; only the root exists
	CuAssertIntEquals(tc, kTrieStartingSize, t->capacity);
	CuAssertIntEquals(tc, 1, t->size);

	trie_insert(t, "foo", 42);

	// State 0: root
	CuAssertIntEquals(tc, 0, t->node[0].match_type);
	CuAssertIntEquals(tc, 1, t->node[0].child['f']);
	CuAssertIntEquals(tc, '\0', t->node[0].c);

	// State 1: 'f'
	CuAssertIntEquals(tc, 0, t->node[1].match_type);
	CuAssertIntEquals(tc, 2, t->node[1].child['o']);
	CuAssertIntEquals(tc, 'f', t->node[1].c);

	// State 2: first 'o'
	CuAssertIntEquals(tc, 0, t->node[2].match_type);
	CuAssertIntEquals(tc, 3, t->node[2].child['o']);
	CuAssertIntEquals(tc, 'o', t->node[2].c);

	// State 3: second 'o', end of key -- carries the match type and length
	CuAssertIntEquals(tc, 42, t->node[3].match_type);
	CuAssertIntEquals(tc, 3, t->node[3].len);
	CuAssertIntEquals(tc, 'o', t->node[3].c);

	trie_free(t);
}
190 #endif
191
192
/// Walk the trie from state `s`, consuming `query` one byte at a time.
///
/// Returns the state reached once the whole query has been consumed, or
/// (size_t) -1 as soon as a byte has no outgoing edge.  An empty query
/// returns `s` itself.
size_t trie_node_search(trie * a, size_t s, const char * query) {
	const char * cursor = query;

	while (*cursor != '\0') {
		size_t next = a->node[s].child[(unsigned char) * cursor];

		if (next == 0) {
			// Failed to match
			return -1;
		}

		// Partial match, keep going
		s = next;
		cursor++;
	}

	// Found matching state
	return s;
}
207
208
/// Look up `query` starting at the root of the trie.
///
/// Returns the state reached by the full query, or (size_t) -1 if the
/// query is not a path in the trie.  When `a` or `query` is NULL the
/// function returns 0 (the root state) without searching.
size_t trie_search(trie * a, const char * query) {
	if (!a || !query) {
		return 0;
	}

	return trie_node_search(a, 0, query);
}
216
217
/// Return the match type stored on the state that `query` leads to.
///
/// Returns (unsigned short) -1 when the query is not a path in the trie
/// or when `a` is NULL.  A NULL `query` resolves to the root state (see
/// trie_search()), so the root's match type is returned in that case.
unsigned short trie_search_match_type(trie * a, const char * query) {
	if (!a) {
		// trie_search() reports state 0 for a NULL trie; indexing
		// a->node[0] would then dereference NULL.
		return (unsigned short) -1;
	}

	size_t s = trie_search(a, query);

	if (s == (size_t) -1) {
		return (unsigned short) -1;
	}

	return a->node[s].match_type;
}
227
228
229 #ifdef TEST
/// Insert three keys and verify the state and match type reported for
/// each, plus the failure values for a key that is not present.
void Test_trie_search(CuTest * tc) {
	// Expected lookup results, checked in this order
	static const struct {
		const char * key;
		int state;
		int match_type;
	} expected[] = {
		{ "foo",  3,  42 },
		{ "bar",  6,  41 },
		{ "food", 7,  40 },
		{ "foot", -1, (unsigned short) - 1 },	// not in the trie
	};

	trie * t = trie_new(0);

	trie_insert(t, "foo", 42);
	trie_insert(t, "bar", 41);
	trie_insert(t, "food", 40);

	for (size_t i = 0; i < sizeof(expected) / sizeof(expected[0]); ++i) {
		CuAssertIntEquals(tc, expected[i].state, trie_search(t, expected[i].key));
		CuAssertIntEquals(tc, expected[i].match_type, trie_search_match_type(t, expected[i].key));
	}

	trie_free(t);
}
251 #endif
252
253
/// Compute the Aho-Corasick failure link for state `s`, then recurse into
/// its children.
///
/// @param a                 trie being prepared
/// @param s                 state to process
/// @param buffer            scratch space holding the characters on the path
///                          from the root to `s` in buffer[0 .. depth-1];
///                          must have room for at least depth + 1 bytes
/// @param depth             number of characters on the path to `s`
/// @param last_match_state  passed through unchanged to recursive calls
///
/// The failure link is the state reached by the first proper suffix of the
/// path (trying longest first) that resolves to a non-root state; it stays
/// 0 if there is none.
void ac_trie_node_prepare(trie * a, size_t s, char * buffer, unsigned short depth, size_t last_match_state) {
	// Terminate the path string for this node.  Only index `depth` is
	// written, so the buffer is never touched past the terminator.
	buffer[depth] = '\0';

	// Current node
	trie_node * n = &(a->node[s]);

	// The root (depth 0) has no proper suffixes to try
	if (depth > 0) {
		// Find valid suffixes for failure path, longest first
		for (const char * suffix = buffer + 1; (suffix[0] != '\0') && (n->ac_fail == 0); ++suffix) {
			size_t fail = trie_search(a, suffix);

			if (fail == (size_t) -1) {
				// Suffix is not in the trie
				fail = 0;
			}

			if (fail == s) {
				// Something went wrong
				fprintf(stderr, "Recursive trie fallback detected at state %zu('%c') - suffix:'%s'!\n", s, n->c, suffix);
				fail = 0;
			}

			n->ac_fail = fail;
		}
	}

	// Prepare children (skipping any edge that points back at this state)
	for (int i = 0; i < 256; ++i) {
		if ((n->child[i] != 0) &&
				(n->child[i] != s)) {
			buffer[depth] = (char)i;

			ac_trie_node_prepare(a, n->child[i], buffer, depth + 1, last_match_state);
		}
	}
}
295
296 /// Prepare trie for Aho-Corasick search algorithm by mapping failure connections
ac_trie_prepare(trie * a)297 void ac_trie_prepare(trie * a) {
298 // Clear old pointers
299 for (size_t i = 0; i < a->size; ++i) {
300 a->node[i].ac_fail = 0;
301 }
302
303 // Create a buffer to use
304 char buffer[a->capacity];
305
306 ac_trie_node_prepare(a, 0, buffer, 0, 0);
307 }
308
309
310
311 #ifdef TEST
/// Smoke test: prepare a trie whose keys are all prefixes of one another
/// ("a", "aa", "aaa", "aaaa" with match types 1 through 4).
void Test_trie_prepare(CuTest * tc) {
	static const char * const keys[] = { "a", "aa", "aaa", "aaaa" };

	trie * t = trie_new(0);

	for (unsigned short i = 0; i < 4; ++i) {
		// Match types run 1..4 in insertion order
		trie_insert(t, keys[i], (unsigned short)(i + 1));
	}

	ac_trie_prepare(t);

	trie_free(t);
}
324 #endif
325
326
match_new(size_t start,size_t len,unsigned short match_type)327 match * match_new(size_t start, size_t len, unsigned short match_type) {
328 match * m = malloc(sizeof(match));
329
330 if (m) {
331 m->start = start;
332 m->len = len;
333 m->match_type = match_type;
334 m->next = NULL;
335 m->prev = NULL;
336 }
337
338 return m;
339 }
340
341
/// Free `m` and every match that follows it through `next`.
///
/// Nodes before `m` (reachable through `prev`) are not touched.  Passing
/// NULL is a no-op.  The list is walked iteratively so that long result
/// lists cannot overflow the call stack.
void match_free(match * m) {
	while (m) {
		// Grab the successor before this node is released
		match * following = m->next;

		free(m);
		m = following;
	}
}
351
352
/// Create a new match and link it directly after `last`.
///
/// If `last` is NULL the new match is returned unlinked.  Otherwise
/// `last->next` is overwritten with the new match, so `last` is expected
/// to be the tail of its list.
///
/// Returns the new match, or NULL if it could not be allocated (in which
/// case `last` is left unchanged).
match * match_add(match * last, size_t start, size_t len, unsigned short match_type) {
	match * added = match_new(start, len, match_type);

	if (added && last) {
		last->next = added;
		added->prev = last;
	}

	return added;
}
364
365
ac_trie_search(trie * a,const char * source,size_t start,size_t len)366 match * ac_trie_search(trie * a, const char * source, size_t start, size_t len) {
367
368 // Store results in a linked list
369 // match * result = match_new(0, 0, 0);
370 match * result = NULL;
371 match * m = result;
372
373 // Keep track of our state
374 size_t state = 0;
375 size_t temp_state;
376
377 // Character being compared
378 unsigned char test_value;
379 size_t counter = start;
380 size_t stop = start + len;
381
382 while ((counter < stop) && (source[counter] != '\0')) {
383 // Read next character
384 test_value = (unsigned char)source[counter++];
385
386 // Check for path that allows us to match next character
387 while (state != 0 && a->node[state].child[test_value] == 0) {
388 state = a->node[state].ac_fail;
389 }
390
391 // Advance state for the next character
392 state = a->node[state].child[test_value];
393
394 // Check for partial matches
395 temp_state = state;
396
397 while (temp_state != 0) {
398 if (a->node[temp_state].match_type) {
399 // This is a match
400 if (!m) {
401 result = match_new(0, 0, 0);
402 m = result;
403 }
404
405 m = match_add(m, counter - a->node[temp_state].len,
406 a->node[temp_state].len, a->node[temp_state].match_type);
407 }
408
409 // Iterate to find shorter matches
410 temp_state = a->node[temp_state].ac_fail;
411 }
412 }
413
414 return result;
415 }
416
417
/// Unlink `m` from its neighbours and free it.  Passing NULL is a no-op.
void match_excise(match * m) {
	if (!m) {
		return;
	}

	match * before = m->prev;
	match * after = m->next;

	// Bridge the gap left behind
	if (before) {
		before->next = after;
	}

	if (after) {
		after->prev = before;
	}

	free(m);
}
431
432
/// Count the matches in a result list.
///
/// `m` is the list header; it is skipped and not counted.  A NULL list
/// (what ac_trie_search() returns when nothing matched) counts as 0.
int match_count(match * m) {
	int result = 0;

	if (!m) {
		return 0;
	}

	// Skip header
	for (m = m->next; m; m = m->next) {
		result++;
	}

	return result;
}
444
445
/// Print one match to stderr as `'text'(type) at start:end`, where `text`
/// is the matched slice of `source`.  Does nothing if either argument is
/// NULL.
void match_describe(match * m, const char * source) {
	if (!m || !source) {
		return;
	}

	// Offsets are printed with %zu through size_t casts so the format is
	// well-defined regardless of platform word size.
	fprintf(stderr, "'%.*s'(%d) at %zu:%zu\n", (int)m->len, &source[m->start],
			m->match_type, (size_t)m->start, (size_t)(m->start + m->len));
}
450
451
/// Print every match in a result list to stderr via match_describe().
///
/// `m` is the list header, which is skipped.  A NULL list prints nothing.
void match_set_describe(match * m, const char * source) {
	if (!m) {
		return;
	}

	// Skip header
	for (m = m->next; m; m = m->next) {
		match_describe(m, source);
	}
}
460
461
/// Filter a result list in place so that only leftmost/longest matches
/// remain.  Removed matches are unlinked and freed with match_excise().
///
/// `header` is the list header node; it is never removed.  Passing NULL is
/// a no-op.
void match_set_filter_leftmost_longest(match * header) {
	if (!header) {
		return;
	}

	// Filter results to include only leftmost/longest results
	match * m = header->next; // Skip header
	match * n;

	while (m) {
		if (m->next) {
			if (m->start == m->next->start) {
				// The next match is longer than this one
				n = m;
				m = m->next;
				match_excise(n);
				continue;
			}

			while (m->next &&
					m->next->start > m->start &&
					m->next->start < m->start + m->len) {
				// This match is "lefter" than next
				#ifndef __clang_analyzer__
				match_excise(m->next);
				#endif
			}

			while (m->next &&
					m->next->start < m->start) {
				// Next match is "lefter" than us
				n = m;
				m = m->prev;
				match_excise(n);
			}
		}

		// The `len` test stops this loop at the header, whose length is 0
		while (m->prev &&
				m->prev->len &&
				m->prev->start >= m->start) {
			// We are "lefter" than previous
			n = m->prev;
			#ifndef __clang_analyzer__
			match_excise(n);
			#endif
		}

		m = m->next;
	}
}
508
509
ac_trie_leftmost_longest_search(trie * a,const char * source,size_t start,size_t len)510 match * ac_trie_leftmost_longest_search(trie * a, const char * source, size_t start, size_t len) {
511 match * result = ac_trie_search(a, source, start, len);
512
513 if (result) {
514 match_set_filter_leftmost_longest(result);
515 }
516
517 return result;
518 }
519
520
521 #ifdef TEST
/// Exercise the full-text search: first a small dictionary against a
/// sentence, then a set of overlapping keys, printing the complete and
/// the leftmost/longest result sets to stderr.
void Test_aho_trie_search(CuTest * tc) {
	// --- Small dictionary ---
	trie * words = trie_new(0);

	trie_insert(words, "foo", 42);
	trie_insert(words, "bar", 41);
	trie_insert(words, "food", 40);

	ac_trie_prepare(words);

	match * found = ac_trie_search(words, "this is a bar that serves food.", 0, 31);

	match_free(found);
	trie_free(words);

	// --- Overlapping keys ---
	static const char text[] = "ABCDEFGGGAZABCABCDZABCABCZ";

	trie * overlapping = trie_new(0);

	trie_insert(overlapping, "A", 1);
	trie_insert(overlapping, "AB", 2);
	trie_insert(overlapping, "ABC", 3);
	trie_insert(overlapping, "BC", 4);
	trie_insert(overlapping, "BCD", 5);
	trie_insert(overlapping, "E", 6);
	trie_insert(overlapping, "EFGHIJ", 7);
	trie_insert(overlapping, "F", 8);
	trie_insert(overlapping, "ZABCABCZ", 9);
	trie_insert(overlapping, "ZAB", 10);

	ac_trie_prepare(overlapping);

	// Every match
	match * all = ac_trie_search(overlapping, text, 0, 26);
	fprintf(stderr, "Finish with %d matches\n", match_count(all));
	match_set_describe(all, text);
	match_free(all);

	// Leftmost/longest matches only
	match * filtered = ac_trie_leftmost_longest_search(overlapping, text, 0, 26);
	fprintf(stderr, "Finish with %d matches\n", match_count(filtered));
	match_set_describe(filtered, text);
	match_free(filtered);

	trie_free(overlapping);
}
566 #endif
567
568
/// Print the Graphviz (DOT) statements for state `s` to stderr: a
/// double-circle shape if the state carries a match type, one labelled
/// edge per child, and a "fail" edge if the state has a non-zero failure
/// link.
void trie_node_to_graphviz(trie * a, size_t s) {
	trie_node * n = &a->node[s];

	if (n->match_type) {
		fprintf(stderr, "\"%zu\" [shape=doublecircle]\n", s);
	}

	for (int i = 0; i < 256; ++i) {
		if (n->child[i]) {
			// A double quote or backslash must be escaped inside a quoted
			// DOT label, otherwise the generated graph does not parse.
			const char * escape = (i == '"' || i == '\\') ? "\\" : "";

			fprintf(stderr, "\"%zu\" -> \"%zu\" [label=\"%s%c\"]\n", s, (size_t)n->child[i], escape, (char)i);
		}
	}

	if (n->ac_fail) {
		fprintf(stderr, "\"%zu\" -> \"%zu\" [label=\"fail\"]\n", s, (size_t)n->ac_fail);
	}
}
589
590
/// Dump the whole trie to stderr as a Graphviz digraph named "dfa", one
/// trie_node_to_graphviz() call per state.  Passing NULL prints nothing.
void trie_to_graphviz(trie * a) {
	if (!a) {
		return;
	}

	fprintf(stderr, "digraph dfa {\n");

	// size_t index: avoids a signed/unsigned comparison against a->size
	for (size_t i = 0; i < a->size; ++i) {
		trie_node_to_graphviz(a, i);
	}

	fprintf(stderr, "}\n");
}
600
601