1 #ifndef MY_AC_H 2 #define MY_AC_H 3 4 #include <string.h> 5 #include <stdio.h> 6 #include <map> 7 #include <vector> 8 #include <algorithm> // for std::sort 9 #include "ac_util.hpp" 10 11 // Forward decl. the acronym "ACS" stands for "Aho-Corasick Slow implementation" 12 class ACS_State; 13 class ACS_Constructor; 14 class AhoCorasick; 15 16 using namespace std; 17 18 typedef std::map<InputTy, ACS_State*> ACS_Goto_Map; 19 20 class Match_Result { 21 public: 22 int begin; 23 int end; 24 int pattern_idx; Match_Result(int b,int e,int p)25 Match_Result(int b, int e, int p): begin(b), end(e), pattern_idx(p) {} 26 }; 27 28 typedef pair<InputTy, ACS_State *> GotoPair; 29 typedef vector<GotoPair> GotoVect; 30 31 // Sorting functor 32 class GotoSort { 33 public: operator ()(const GotoPair & g1,const GotoPair & g2)34 bool operator() (const GotoPair& g1, const GotoPair& g2) { 35 return g1.first < g2.first; 36 } 37 }; 38 39 class ACS_State { 40 friend class ACS_Constructor; 41 42 public: ACS_State(uint32 id)43 ACS_State(uint32 id): _id(id), _pattern_idx(-1), _depth(0), 44 _is_terminal(false), _fail_link(0){} ~ACS_State()45 ~ACS_State() {}; 46 Set_Goto(InputTy c,ACS_State * s)47 void Set_Goto(InputTy c, ACS_State* s) { _goto_map[c] = s; } Get_Goto(InputTy c) const48 ACS_State *Get_Goto(InputTy c) const { 49 ACS_Goto_Map::const_iterator iter = _goto_map.find(c); 50 return iter != _goto_map.end() ? (*iter).second : 0; 51 } 52 53 // Return all transitions sorted in the ascending order of their input. Get_Sorted_Gotos(GotoVect & Gotos) const54 void Get_Sorted_Gotos(GotoVect& Gotos) const { 55 const ACS_Goto_Map& m = _goto_map; 56 Gotos.clear(); 57 for (ACS_Goto_Map::const_iterator i = m.begin(), e = m.end(); 58 i != e; i++) { 59 Gotos.push_back(GotoPair(i->first, i->second)); 60 } 61 sort(Gotos.begin(), Gotos.end(), GotoSort()); 62 } 63 Get_FailLink() const64 ACS_State* Get_FailLink() const { return _fail_link; } Get_GotoNum() const65 uint32 Get_GotoNum() const { return _goto_map.size(); } Get_ID() const66 uint32 Get_ID() const { return _id; } Get_Depth() const67 uint32 Get_Depth() const { return _depth; } Get_Goto_Map(void) const68 const ACS_Goto_Map& Get_Goto_Map(void) const { return _goto_map; } is_Terminal() const69 bool is_Terminal() const { return _is_terminal; } get_Pattern_Idx() const70 int get_Pattern_Idx() const { 71 ASSERT(is_Terminal() && _pattern_idx >= 0); 72 return _pattern_idx; 73 } 74 75 private: set_Pattern_Idx(int idx)76 void set_Pattern_Idx(int idx) { 77 ASSERT(is_Terminal()); 78 _pattern_idx = idx; 79 } 80 81 private: 82 uint32 _id; 83 int _pattern_idx; 84 short _depth; 85 bool _is_terminal; 86 ACS_Goto_Map _goto_map; 87 ACS_State* _fail_link; 88 }; 89 90 class ACS_Constructor { 91 public: 92 ACS_Constructor(); 93 ~ACS_Constructor(); 94 95 void Construct(const char** strv, unsigned int* strlenv, 96 unsigned int strnum); 97 Match(const char * s,uint32 len) const98 Match_Result Match(const char* s, uint32 len) const { 99 Match_Result r = MatchHelper(s, len); 100 Verify_Result(s, &r); 101 return r; 102 } 103 Match(const char * s) const104 Match_Result Match(const char* s) const { return Match(s, strlen(s)); } 105 106 #ifdef DEBUG 107 void dump_text(const char* = "ac.txt") const; 108 void dump_dot(const char* = "ac.dot") const; 109 #endif Get_Root_State() const110 const ACS_State *Get_Root_State() const { return _root; } Get_All_States() const111 const vector<ACS_State*>& Get_All_States() const { 112 return _all_states; 113 } 114 Get_Next_Node_Id() const115 uint32 Get_Next_Node_Id() const { return _next_node_id; } Get_State_Num() const116 uint32 Get_State_Num() const { return _next_node_id - 1; } 117 118 private: 119 void Add_Pattern(const char* str, unsigned int str_len, int pattern_idx); 120 ACS_State* new_state(); 121 void Propagate_faillink(); 122 123 Match_Result MatchHelper(const char*, uint32 len) const; 124 125 #ifdef VERIFY 126 void Verify_Result(const char* subject, const Match_Result* r) const; 127 void Save_Patterns(const char** strv, unsigned int* strlenv, int vect_len); get_ith_Pattern(unsigned i) const128 const char* get_ith_Pattern(unsigned i) const { 129 ASSERT(i < _pattern_vect.size()); 130 return _pattern_vect.at(i); 131 } get_ith_Pattern_Len(unsigned i) const132 unsigned get_ith_Pattern_Len(unsigned i) const { 133 ASSERT(i < _pattern_lens.size()); 134 return _pattern_lens.at(i); 135 } 136 #else Verify_Result(const char * subject,const Match_Result * r) const137 void Verify_Result(const char* subject, const Match_Result* r) const { 138 (void)subject; (void)r; 139 } Save_Patterns(const char ** strv,unsigned int * strlenv,int vect_len)140 void Save_Patterns(const char** strv, unsigned int* strlenv, int vect_len) { 141 (void)strv; (void)strlenv; 142 } 143 #endif 144 145 private: 146 ACS_State* _root; 147 vector<ACS_State*> _all_states; 148 unsigned char* _root_char; 149 uint32 _next_node_id; 150 151 #ifdef VERIFY 152 char* _pattern_buf; 153 vector<int> _pattern_lens; 154 vector<char*> _pattern_vect; 155 #endif 156 }; 157 158 #endif 159