1 #ifndef MY_AC_H
2 #define MY_AC_H
3 
4 #include <string.h>
5 #include <stdio.h>
6 #include <map>
7 #include <vector>
8 #include <algorithm> // for std::sort
9 #include "ac_util.hpp"
10 
// Forward declarations. The "ACS" prefix stands for "Aho-Corasick Slow
// implementation".
class ACS_State;
class ACS_Constructor;
class AhoCorasick;

// NOTE(review): a using-directive in a header injects every std:: name into
// each translation unit that includes it. It is kept because the declarations
// below (pair, vector, sort) are written unqualified, and includers presumably
// rely on it as well — confirm before removing.
using namespace std;

// Transition table of one state: input symbol -> successor state.
// std::map keeps the entries ordered by key.
typedef std::map<InputTy, ACS_State*> ACS_Goto_Map;
19 
// Outcome of matching a subject string: the begin/end positions reported by
// the matcher and the index of the pattern involved. All three fields are
// plain public ints, set once by the constructor.
class Match_Result {
public:
    int begin;        // start position reported by the matcher
    int end;          // end position reported by the matcher
    int pattern_idx;  // index of the pattern the match refers to

    Match_Result(int begin_pos, int end_pos, int pat_idx)
        : begin(begin_pos),
          end(end_pos),
          pattern_idx(pat_idx) {
    }
};
27 
// One transition: (input symbol, target state).
typedef pair<InputTy, ACS_State *> GotoPair;
// A flat list of transitions, as filled in by ACS_State::Get_Sorted_Gotos().
typedef vector<GotoPair> GotoVect;
30 
31 // Sorting functor
32 class GotoSort {
33 public:
operator ()(const GotoPair & g1,const GotoPair & g2)34     bool operator() (const GotoPair& g1, const GotoPair& g2) {
35         return g1.first < g2.first;
36     }
37 };
38 
39 class ACS_State {
40 friend class ACS_Constructor;
41 
42 public:
ACS_State(uint32 id)43     ACS_State(uint32 id): _id(id), _pattern_idx(-1), _depth(0),
44                           _is_terminal(false), _fail_link(0){}
~ACS_State()45     ~ACS_State() {};
46 
Set_Goto(InputTy c,ACS_State * s)47     void Set_Goto(InputTy c, ACS_State* s) { _goto_map[c] = s; }
Get_Goto(InputTy c) const48     ACS_State *Get_Goto(InputTy c) const {
49         ACS_Goto_Map::const_iterator iter = _goto_map.find(c);
50         return iter != _goto_map.end() ? (*iter).second : 0;
51     }
52 
53     // Return all transitions sorted in the ascending order of their input.
Get_Sorted_Gotos(GotoVect & Gotos) const54     void Get_Sorted_Gotos(GotoVect& Gotos) const {
55         const ACS_Goto_Map& m = _goto_map;
56         Gotos.clear();
57         for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end();
58                 i != e; i++) {
59             Gotos.push_back(GotoPair(i->first, i->second));
60         }
61         sort(Gotos.begin(), Gotos.end(), GotoSort());
62     }
63 
Get_FailLink() const64     ACS_State* Get_FailLink() const { return _fail_link; }
Get_GotoNum() const65     uint32 Get_GotoNum() const { return _goto_map.size(); }
Get_ID() const66     uint32 Get_ID() const { return _id; }
Get_Depth() const67     uint32 Get_Depth() const { return _depth; }
Get_Goto_Map(void) const68     const ACS_Goto_Map& Get_Goto_Map(void) const { return _goto_map; }
is_Terminal() const69     bool is_Terminal() const { return _is_terminal; }
get_Pattern_Idx() const70     int get_Pattern_Idx() const {
71         ASSERT(is_Terminal() && _pattern_idx >= 0);
72         return _pattern_idx;
73     }
74 
75 private:
set_Pattern_Idx(int idx)76     void set_Pattern_Idx(int idx) {
77         ASSERT(is_Terminal());
78         _pattern_idx = idx;
79     }
80 
81 private:
82     uint32 _id;
83     int _pattern_idx;
84     short _depth;
85     bool _is_terminal;
86     ACS_Goto_Map _goto_map;
87     ACS_State* _fail_link;
88 };
89 
90 class ACS_Constructor {
91 public:
92     ACS_Constructor();
93     ~ACS_Constructor();
94 
95     void Construct(const char** strv, unsigned int* strlenv,
96                    unsigned int strnum);
97 
Match(const char * s,uint32 len) const98     Match_Result Match(const char* s, uint32 len) const {
99         Match_Result r = MatchHelper(s, len);
100         Verify_Result(s, &r);
101         return r;
102     }
103 
Match(const char * s) const104     Match_Result Match(const char* s) const { return Match(s, strlen(s)); }
105 
106 #ifdef DEBUG
107     void dump_text(const char* = "ac.txt") const;
108     void dump_dot(const char* = "ac.dot") const;
109 #endif
Get_Root_State() const110     const ACS_State *Get_Root_State() const { return _root; }
Get_All_States() const111     const vector<ACS_State*>& Get_All_States() const {
112         return _all_states;
113     }
114 
Get_Next_Node_Id() const115     uint32 Get_Next_Node_Id() const { return _next_node_id; }
Get_State_Num() const116     uint32 Get_State_Num() const { return _next_node_id - 1; }
117 
118 private:
119     void Add_Pattern(const char* str, unsigned int str_len, int pattern_idx);
120     ACS_State* new_state();
121     void Propagate_faillink();
122 
123     Match_Result MatchHelper(const char*, uint32 len) const;
124 
125 #ifdef VERIFY
126     void Verify_Result(const char* subject, const Match_Result* r) const;
127     void Save_Patterns(const char** strv, unsigned int* strlenv, int vect_len);
get_ith_Pattern(unsigned i) const128     const char* get_ith_Pattern(unsigned i) const {
129         ASSERT(i < _pattern_vect.size());
130         return _pattern_vect.at(i);
131     }
get_ith_Pattern_Len(unsigned i) const132     unsigned get_ith_Pattern_Len(unsigned i) const {
133         ASSERT(i < _pattern_lens.size());
134         return _pattern_lens.at(i);
135     }
136 #else
Verify_Result(const char * subject,const Match_Result * r) const137     void Verify_Result(const char* subject, const Match_Result* r) const {
138         (void)subject; (void)r;
139     }
Save_Patterns(const char ** strv,unsigned int * strlenv,int vect_len)140     void Save_Patterns(const char** strv, unsigned int* strlenv, int vect_len) {
141         (void)strv; (void)strlenv;
142     }
143 #endif
144 
145 private:
146     ACS_State* _root;
147     vector<ACS_State*> _all_states;
148     unsigned char* _root_char;
149     uint32 _next_node_id;
150 
151 #ifdef VERIFY
152     char* _pattern_buf;
153     vector<int> _pattern_lens;
154     vector<char*> _pattern_vect;
155 #endif
156 };
157 
158 #endif
159