1 #include <ctype.h>
2 #include <strings.h> // for bzero
3 #include <algorithm>
4 #include "ac_slow.hpp"
5 #include "ac.h"
6 
7 //////////////////////////////////////////////////////////////////////////
8 //
9 //      Implementation of AhoCorasick_Slow
10 //
11 //////////////////////////////////////////////////////////////////////////
12 //
ACS_Constructor()13 ACS_Constructor::ACS_Constructor() : _next_node_id(1) {
14     _root = new_state();
15     _root_char = new InputTy[256];
16     bzero((void*)_root_char, 256);
17 
18 #ifdef VERIFY
19     _pattern_buf = 0;
20 #endif
21 }
22 
~ACS_Constructor()23 ACS_Constructor::~ACS_Constructor() {
24     for (std::vector<ACS_State* >::iterator i =  _all_states.begin(),
25             e = _all_states.end(); i != e; i++) {
26         delete *i;
27     }
28     _all_states.clear();
29     delete[] _root_char;
30 
31 #ifdef VERIFY
32     delete[] _pattern_buf;
33 #endif
34 }
35 
36 ACS_State*
new_state()37 ACS_Constructor::new_state() {
38     ACS_State* t = new ACS_State(_next_node_id++);
39     _all_states.push_back(t);
40     return t;
41 }
42 
43 void
Add_Pattern(const char * str,unsigned int str_len,int pattern_idx)44 ACS_Constructor::Add_Pattern(const char* str, unsigned int str_len,
45                              int pattern_idx) {
46     ACS_State* state = _root;
47     for (unsigned int i = 0; i < str_len; i++) {
48         const char c = str[i];
49         ACS_State* new_s = state->Get_Goto(c);
50         if (!new_s) {
51             new_s = new_state();
52             new_s->_depth = state->_depth + 1;
53             state->Set_Goto(c, new_s);
54         }
55         state = new_s;
56     }
57     state->_is_terminal = true;
58     state->set_Pattern_Idx(pattern_idx);
59 }
60 
61 void
Propagate_faillink()62 ACS_Constructor::Propagate_faillink() {
63     ACS_State* r = _root;
64     std::vector<ACS_State*> wl;
65 
66     const ACS_Goto_Map& m = r->Get_Goto_Map();
67     for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end(); i != e; i++) {
68         ACS_State* s = i->second;
69         s->_fail_link = r;
70         wl.push_back(s);
71     }
72 
73     // For any input c, make sure "goto(root, c)" is valid, which make the
74     // fail-link propagation lot easier.
75     ACS_Goto_Map goto_save = r->_goto_map;
76     for (uint32 i = 0; i <= 255; i++) {
77         ACS_State* s = r->Get_Goto(i);
78         if (!s) r->Set_Goto(i, r);
79     }
80 
81     for (uint32 i = 0; i < wl.size(); i++) {
82         ACS_State* s = wl[i];
83         ACS_State* fl = s->_fail_link;
84 
85         const ACS_Goto_Map& tran_map = s->Get_Goto_Map();
86 
87         for (ACS_Goto_Map::const_iterator ii = tran_map.begin(),
88                 ee = tran_map.end(); ii != ee; ii++) {
89             InputTy c = ii->first;
90             ACS_State *tran = ii->second;
91 
92             ACS_State* tran_fl = 0;
93             for (ACS_State* fl_walk = fl; ;) {
94                 if (ACS_State* t = fl_walk->Get_Goto(c)) {
95                     tran_fl = t;
96                     break;
97                 } else {
98                     fl_walk = fl_walk->Get_FailLink();
99                 }
100             }
101 
102             tran->_fail_link = tran_fl;
103             wl.push_back(tran);
104         }
105     }
106 
107     // Remove "goto(root, c) == root" transitions
108     r->_goto_map = goto_save;
109 }
110 
111 void
Construct(const char ** strv,unsigned int * strlenv,uint32 strnum)112 ACS_Constructor::Construct(const char** strv, unsigned int* strlenv,
113                            uint32 strnum) {
114     Save_Patterns(strv, strlenv, strnum);
115 
116     for (uint32 i = 0; i < strnum; i++) {
117         Add_Pattern(strv[i], strlenv[i], i);
118     }
119 
120     Propagate_faillink();
121     unsigned char* p = _root_char;
122 
123     const ACS_Goto_Map& m = _root->Get_Goto_Map();
124     for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end();
125             i != e; i++) {
126         p[i->first] = 1;
127     }
128 }
129 
130 Match_Result
MatchHelper(const char * str,uint32 len) const131 ACS_Constructor::MatchHelper(const char *str, uint32 len) const {
132     const ACS_State* root = _root;
133     const ACS_State* state = root;
134 
135     uint32 idx = 0;
136     while (idx < len) {
137         InputTy c = str[idx];
138         idx++;
139         if (_root_char[c]) {
140             state = root->Get_Goto(c);
141             break;
142         }
143     }
144 
145     if (unlikely(state->is_Terminal())) {
146         // This could happen if the one of the pattern has only one char!
147         uint32 pos = idx - 1;
148         Match_Result r(pos - state->Get_Depth() + 1, pos,
149                        state->get_Pattern_Idx());
150         return r;
151     }
152 
153     while (idx < len) {
154         InputTy c = str[idx];
155         ACS_State* gs = state->Get_Goto(c);
156 
157         if (!gs) {
158             ACS_State* fl = state->Get_FailLink();
159             if (fl == root) {
160                 while (idx < len) {
161                     InputTy c = str[idx];
162                     idx++;
163                     if (_root_char[c]) {
164                         state = root->Get_Goto(c);
165                         break;
166                     }
167                 }
168             } else {
169                 state = fl;
170             }
171         } else {
172             idx ++;
173             state = gs;
174         }
175 
176         if (state->is_Terminal()) {
177             uint32 pos = idx - 1;
178             Match_Result r = Match_Result(pos - state->Get_Depth() + 1, pos,
179                                           state->get_Pattern_Idx());
180             return r;
181         }
182     }
183 
184     return Match_Result(-1, -1, -1);
185 }
186 
187 #ifdef DEBUG
188 void
dump_text(const char * txtfile) const189 ACS_Constructor::dump_text(const char* txtfile) const {
190     FILE* f = fopen(txtfile, "w+");
191     for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(),
192             e = _all_states.end(); i != e; i++) {
193         ACS_State* s = *i;
194 
195         fprintf(f, "S%d goto:{", s->Get_ID());
196         const ACS_Goto_Map& goto_func = s->Get_Goto_Map();
197 
198         for (ACS_Goto_Map::const_iterator i = goto_func.begin(), e = goto_func.end();
199               i != e; i++) {
200             InputTy input = i->first;
201             ACS_State* tran = i->second;
202             if (isprint(input))
203                 fprintf(f, "'%c' -> S:%d,", input, tran->Get_ID());
204             else
205                 fprintf(f, "%#x -> S:%d,", input, tran->Get_ID());
206         }
207         fprintf(f, "} ");
208 
209         if (s->_fail_link) {
210             fprintf(f, ", fail=S:%d", s->_fail_link->Get_ID());
211         }
212 
213         if (s->_is_terminal) {
214             fprintf(f, ", terminal");
215         }
216 
217         fprintf(f, "\n");
218     }
219     fclose(f);
220 }
221 
222 void
dump_dot(const char * dotfile) const223 ACS_Constructor::dump_dot(const char *dotfile) const {
224     FILE* f = fopen(dotfile, "w+");
225     const char* indent = "  ";
226 
227     fprintf(f, "digraph G {\n");
228 
229     // Emit node information
230     fprintf(f, "%s%d [style=filled];\n", indent, _root->Get_ID());
231     for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(),
232             e = _all_states.end(); i != e; i++) {
233         ACS_State *s = *i;
234         if (s->_is_terminal) {
235             fprintf(f, "%s%d [shape=doublecircle];\n", indent, s->Get_ID());
236         }
237     }
238     fprintf(f, "\n");
239 
240     // Emit edge information
241     for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(),
242             e = _all_states.end(); i != e; i++) {
243         ACS_State* s = *i;
244         uint32 id = s->Get_ID();
245 
246         const ACS_Goto_Map& m = s->Get_Goto_Map();
247         for (ACS_Goto_Map::const_iterator ii = m.begin(), ee = m.end();
248              ii != ee; ii++) {
249             InputTy input = ii->first;
250             ACS_State* tran = ii->second;
251             if (isalnum(input))
252                 fprintf(f, "%s%d -> %d [label=%c];\n",
253                         indent, id, tran->Get_ID(), input);
254             else
255                 fprintf(f, "%s%d -> %d [label=\"%#x\"];\n",
256                         indent, id, tran->Get_ID(), input);
257 
258         }
259 
260         // Emit fail-link
261         ACS_State* fl = s->Get_FailLink();
262         if (fl && fl != _root) {
263             fprintf(f, "%s%d -> %d [style=dotted, color=red]; \n",
264                     indent, id, fl->Get_ID());
265         }
266     }
267     fprintf(f, "}\n");
268     fclose(f);
269 }
270 #endif
271 
272 #ifdef VERIFY
273 void
Verify_Result(const char * subject,const Match_Result * r) const274 ACS_Constructor::Verify_Result(const char* subject, const Match_Result* r)
275     const {
276     if (r->begin >= 0) {
277         unsigned len = r->end - r->begin + 1;
278         int ptn_idx = r->pattern_idx;
279 
280         ASSERT(ptn_idx >= 0 &&
281                len == get_ith_Pattern_Len(ptn_idx) &&
282                memcmp(subject + r->begin, get_ith_Pattern(ptn_idx), len) == 0);
283     }
284 }
285 
286 void
Save_Patterns(const char ** strv,unsigned int * strlenv,int pattern_num)287 ACS_Constructor::Save_Patterns(const char** strv, unsigned int* strlenv,
288                                int pattern_num) {
289     // calculate the total size needed to save all patterns.
290     //
291     int buf_size = 0;
292     for (int i = 0; i < pattern_num; i++) { buf_size += strlenv[i]; }
293 
294     // HINT: patterns are delimited by '\0' in order to ease debugging.
295     buf_size += pattern_num;
296     ASSERT(_pattern_buf == 0);
297     _pattern_buf = new char[buf_size + 1];
298     #define MAGIC_NUM 0x5a
299     _pattern_buf[buf_size] = MAGIC_NUM;
300 
301     int ofst = 0;
302     _pattern_lens.resize(pattern_num);
303     _pattern_vect.resize(pattern_num);
304     for (int i = 0; i < pattern_num; i++) {
305         int l = strlenv[i];
306         _pattern_lens[i] = l;
307         _pattern_vect[i] = _pattern_buf + ofst;
308 
309         memcpy(_pattern_buf + ofst, strv[i], l);
310         ofst += l;
311         _pattern_buf[ofst++] = '\0';
312     }
313 
314     ASSERT(_pattern_buf[buf_size] == MAGIC_NUM);
315     #undef MAGIC_NUM
316 }
317 
318 #endif
319