1 #include <ctype.h>
2 #include <strings.h> // for bzero
3 #include <algorithm>
4 #include "ac_slow.hpp"
5 #include "ac.h"
6
7 //////////////////////////////////////////////////////////////////////////
8 //
9 // Implementation of AhoCorasick_Slow
10 //
11 //////////////////////////////////////////////////////////////////////////
12 //
ACS_Constructor()13 ACS_Constructor::ACS_Constructor() : _next_node_id(1) {
14 _root = new_state();
15 _root_char = new InputTy[256];
16 bzero((void*)_root_char, 256);
17
18 #ifdef VERIFY
19 _pattern_buf = 0;
20 #endif
21 }
22
~ACS_Constructor()23 ACS_Constructor::~ACS_Constructor() {
24 for (std::vector<ACS_State* >::iterator i = _all_states.begin(),
25 e = _all_states.end(); i != e; i++) {
26 delete *i;
27 }
28 _all_states.clear();
29 delete[] _root_char;
30
31 #ifdef VERIFY
32 delete[] _pattern_buf;
33 #endif
34 }
35
36 ACS_State*
new_state()37 ACS_Constructor::new_state() {
38 ACS_State* t = new ACS_State(_next_node_id++);
39 _all_states.push_back(t);
40 return t;
41 }
42
43 void
Add_Pattern(const char * str,unsigned int str_len,int pattern_idx)44 ACS_Constructor::Add_Pattern(const char* str, unsigned int str_len,
45 int pattern_idx) {
46 ACS_State* state = _root;
47 for (unsigned int i = 0; i < str_len; i++) {
48 const char c = str[i];
49 ACS_State* new_s = state->Get_Goto(c);
50 if (!new_s) {
51 new_s = new_state();
52 new_s->_depth = state->_depth + 1;
53 state->Set_Goto(c, new_s);
54 }
55 state = new_s;
56 }
57 state->_is_terminal = true;
58 state->set_Pattern_Idx(pattern_idx);
59 }
60
61 void
Propagate_faillink()62 ACS_Constructor::Propagate_faillink() {
63 ACS_State* r = _root;
64 std::vector<ACS_State*> wl;
65
66 const ACS_Goto_Map& m = r->Get_Goto_Map();
67 for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end(); i != e; i++) {
68 ACS_State* s = i->second;
69 s->_fail_link = r;
70 wl.push_back(s);
71 }
72
73 // For any input c, make sure "goto(root, c)" is valid, which make the
74 // fail-link propagation lot easier.
75 ACS_Goto_Map goto_save = r->_goto_map;
76 for (uint32 i = 0; i <= 255; i++) {
77 ACS_State* s = r->Get_Goto(i);
78 if (!s) r->Set_Goto(i, r);
79 }
80
81 for (uint32 i = 0; i < wl.size(); i++) {
82 ACS_State* s = wl[i];
83 ACS_State* fl = s->_fail_link;
84
85 const ACS_Goto_Map& tran_map = s->Get_Goto_Map();
86
87 for (ACS_Goto_Map::const_iterator ii = tran_map.begin(),
88 ee = tran_map.end(); ii != ee; ii++) {
89 InputTy c = ii->first;
90 ACS_State *tran = ii->second;
91
92 ACS_State* tran_fl = 0;
93 for (ACS_State* fl_walk = fl; ;) {
94 if (ACS_State* t = fl_walk->Get_Goto(c)) {
95 tran_fl = t;
96 break;
97 } else {
98 fl_walk = fl_walk->Get_FailLink();
99 }
100 }
101
102 tran->_fail_link = tran_fl;
103 wl.push_back(tran);
104 }
105 }
106
107 // Remove "goto(root, c) == root" transitions
108 r->_goto_map = goto_save;
109 }
110
111 void
Construct(const char ** strv,unsigned int * strlenv,uint32 strnum)112 ACS_Constructor::Construct(const char** strv, unsigned int* strlenv,
113 uint32 strnum) {
114 Save_Patterns(strv, strlenv, strnum);
115
116 for (uint32 i = 0; i < strnum; i++) {
117 Add_Pattern(strv[i], strlenv[i], i);
118 }
119
120 Propagate_faillink();
121 unsigned char* p = _root_char;
122
123 const ACS_Goto_Map& m = _root->Get_Goto_Map();
124 for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end();
125 i != e; i++) {
126 p[i->first] = 1;
127 }
128 }
129
130 Match_Result
MatchHelper(const char * str,uint32 len) const131 ACS_Constructor::MatchHelper(const char *str, uint32 len) const {
132 const ACS_State* root = _root;
133 const ACS_State* state = root;
134
135 uint32 idx = 0;
136 while (idx < len) {
137 InputTy c = str[idx];
138 idx++;
139 if (_root_char[c]) {
140 state = root->Get_Goto(c);
141 break;
142 }
143 }
144
145 if (unlikely(state->is_Terminal())) {
146 // This could happen if the one of the pattern has only one char!
147 uint32 pos = idx - 1;
148 Match_Result r(pos - state->Get_Depth() + 1, pos,
149 state->get_Pattern_Idx());
150 return r;
151 }
152
153 while (idx < len) {
154 InputTy c = str[idx];
155 ACS_State* gs = state->Get_Goto(c);
156
157 if (!gs) {
158 ACS_State* fl = state->Get_FailLink();
159 if (fl == root) {
160 while (idx < len) {
161 InputTy c = str[idx];
162 idx++;
163 if (_root_char[c]) {
164 state = root->Get_Goto(c);
165 break;
166 }
167 }
168 } else {
169 state = fl;
170 }
171 } else {
172 idx ++;
173 state = gs;
174 }
175
176 if (state->is_Terminal()) {
177 uint32 pos = idx - 1;
178 Match_Result r = Match_Result(pos - state->Get_Depth() + 1, pos,
179 state->get_Pattern_Idx());
180 return r;
181 }
182 }
183
184 return Match_Result(-1, -1, -1);
185 }
186
187 #ifdef DEBUG
188 void
dump_text(const char * txtfile) const189 ACS_Constructor::dump_text(const char* txtfile) const {
190 FILE* f = fopen(txtfile, "w+");
191 for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(),
192 e = _all_states.end(); i != e; i++) {
193 ACS_State* s = *i;
194
195 fprintf(f, "S%d goto:{", s->Get_ID());
196 const ACS_Goto_Map& goto_func = s->Get_Goto_Map();
197
198 for (ACS_Goto_Map::const_iterator i = goto_func.begin(), e = goto_func.end();
199 i != e; i++) {
200 InputTy input = i->first;
201 ACS_State* tran = i->second;
202 if (isprint(input))
203 fprintf(f, "'%c' -> S:%d,", input, tran->Get_ID());
204 else
205 fprintf(f, "%#x -> S:%d,", input, tran->Get_ID());
206 }
207 fprintf(f, "} ");
208
209 if (s->_fail_link) {
210 fprintf(f, ", fail=S:%d", s->_fail_link->Get_ID());
211 }
212
213 if (s->_is_terminal) {
214 fprintf(f, ", terminal");
215 }
216
217 fprintf(f, "\n");
218 }
219 fclose(f);
220 }
221
222 void
dump_dot(const char * dotfile) const223 ACS_Constructor::dump_dot(const char *dotfile) const {
224 FILE* f = fopen(dotfile, "w+");
225 const char* indent = " ";
226
227 fprintf(f, "digraph G {\n");
228
229 // Emit node information
230 fprintf(f, "%s%d [style=filled];\n", indent, _root->Get_ID());
231 for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(),
232 e = _all_states.end(); i != e; i++) {
233 ACS_State *s = *i;
234 if (s->_is_terminal) {
235 fprintf(f, "%s%d [shape=doublecircle];\n", indent, s->Get_ID());
236 }
237 }
238 fprintf(f, "\n");
239
240 // Emit edge information
241 for (std::vector<ACS_State*>::const_iterator i = _all_states.begin(),
242 e = _all_states.end(); i != e; i++) {
243 ACS_State* s = *i;
244 uint32 id = s->Get_ID();
245
246 const ACS_Goto_Map& m = s->Get_Goto_Map();
247 for (ACS_Goto_Map::const_iterator ii = m.begin(), ee = m.end();
248 ii != ee; ii++) {
249 InputTy input = ii->first;
250 ACS_State* tran = ii->second;
251 if (isalnum(input))
252 fprintf(f, "%s%d -> %d [label=%c];\n",
253 indent, id, tran->Get_ID(), input);
254 else
255 fprintf(f, "%s%d -> %d [label=\"%#x\"];\n",
256 indent, id, tran->Get_ID(), input);
257
258 }
259
260 // Emit fail-link
261 ACS_State* fl = s->Get_FailLink();
262 if (fl && fl != _root) {
263 fprintf(f, "%s%d -> %d [style=dotted, color=red]; \n",
264 indent, id, fl->Get_ID());
265 }
266 }
267 fprintf(f, "}\n");
268 fclose(f);
269 }
270 #endif
271
272 #ifdef VERIFY
273 void
Verify_Result(const char * subject,const Match_Result * r) const274 ACS_Constructor::Verify_Result(const char* subject, const Match_Result* r)
275 const {
276 if (r->begin >= 0) {
277 unsigned len = r->end - r->begin + 1;
278 int ptn_idx = r->pattern_idx;
279
280 ASSERT(ptn_idx >= 0 &&
281 len == get_ith_Pattern_Len(ptn_idx) &&
282 memcmp(subject + r->begin, get_ith_Pattern(ptn_idx), len) == 0);
283 }
284 }
285
286 void
Save_Patterns(const char ** strv,unsigned int * strlenv,int pattern_num)287 ACS_Constructor::Save_Patterns(const char** strv, unsigned int* strlenv,
288 int pattern_num) {
289 // calculate the total size needed to save all patterns.
290 //
291 int buf_size = 0;
292 for (int i = 0; i < pattern_num; i++) { buf_size += strlenv[i]; }
293
294 // HINT: patterns are delimited by '\0' in order to ease debugging.
295 buf_size += pattern_num;
296 ASSERT(_pattern_buf == 0);
297 _pattern_buf = new char[buf_size + 1];
298 #define MAGIC_NUM 0x5a
299 _pattern_buf[buf_size] = MAGIC_NUM;
300
301 int ofst = 0;
302 _pattern_lens.resize(pattern_num);
303 _pattern_vect.resize(pattern_num);
304 for (int i = 0; i < pattern_num; i++) {
305 int l = strlenv[i];
306 _pattern_lens[i] = l;
307 _pattern_vect[i] = _pattern_buf + ofst;
308
309 memcpy(_pattern_buf + ofst, strv[i], l);
310 ofst += l;
311 _pattern_buf[ofst++] = '\0';
312 }
313
314 ASSERT(_pattern_buf[buf_size] == MAGIC_NUM);
315 #undef MAGIC_NUM
316 }
317
318 #endif
319