1 #ifndef LM_FILTER_PHRASE_H
2 #define LM_FILTER_PHRASE_H
3 
#include "util/murmur_hash.hh"
#include "util/string_piece.hh"
#include "util/tokenize_piece.hh"

#include <boost/unordered_map.hpp>

#include <cstddef>
#include <iosfwd>
#include <stdint.h>
#include <vector>
12 
// Expands to a const member function
//   bool Find<caps>(Hash key, const std::vector<unsigned int> *&out) const
// for a class that has a Table typedef and a table_ member.  On a hit, out is
// pointed at the `lower` field of the mapped value and true is returned; on a
// miss, out is left untouched and false is returned.
#define LM_FILTER_PHRASE_METHOD(caps, lower) \
bool Find##caps(Hash key, const std::vector<unsigned int> *&out) const {\
  const Table::const_iterator found = table_.find(key);\
  if (found != table_.end()) {\
    out = &found->second.lower;\
    return true;\
  }\
  return false;\
}
20 
21 namespace lm {
22 namespace phrase {
23 
// Hash of a word or of a word sequence.  Substrings keys its table on these
// (chained through util::MurmurHashNative in Substrings::AddPhrase), and
// detail::MakeHashes produces one per word of an n-gram.
typedef uint64_t Hash;
25 
26 class Substrings {
27   private:
28     /* This is the value in a hash table where the key is a string.  It indicates
29      * four sets of sentences:
30      * substring is sentences with a phrase containing the key as a substring.
31      * left is sentencess with a phrase that begins with the key (left aligned).
32      * right is sentences with a phrase that ends with the key (right aligned).
33      * phrase is sentences where the key is a phrase.
34      * Each set is encoded as a vector of sentence ids in increasing order.
35      */
36     struct SentenceRelation {
37       std::vector<unsigned int> substring, left, right, phrase;
38     };
39     /* Most of the CPU is hash table lookups, so let's not complicate it with
40      * vector equality comparisons.  If a collision happens, the SentenceRelation
41      * structure will contain the union of sentence ids over the colliding strings.
42      * In that case, the filter will be slightly more permissive.
43      * The key here is the same as boost's hash of std::vector<std::string>.
44      */
45     typedef boost::unordered_map<Hash, SentenceRelation> Table;
46 
47   public:
Substrings()48     Substrings() {}
49 
50     /* If the string isn't a substring of any phrase, return NULL.  Otherwise,
51      * return a pointer to std::vector<unsigned int> listing sentences with
52      * matching phrases.  This set may be empty for Left, Right, or Phrase.
53      * Example: const std::vector<unsigned int> *FindSubstring(Hash key)
54      */
LM_FILTER_PHRASE_METHOD(Substring,substring)55     LM_FILTER_PHRASE_METHOD(Substring, substring)
56     LM_FILTER_PHRASE_METHOD(Left, left)
57     LM_FILTER_PHRASE_METHOD(Right, right)
58     LM_FILTER_PHRASE_METHOD(Phrase, phrase)
59 
60 #pragma GCC diagnostic ignored "-Wuninitialized" // end != finish so there's always an initialization
61     // sentence_id must be non-decreasing.  Iterators are over words in the phrase.
62     template <class Iterator> void AddPhrase(unsigned int sentence_id, const Iterator &begin, const Iterator &end) {
63       // Iterate over all substrings.
64       for (Iterator start = begin; start != end; ++start) {
65         Hash hash = 0;
66         SentenceRelation *relation;
67         for (Iterator finish = start; finish != end; ++finish) {
68           hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *finish);
69           // Now hash is of [start, finish].
70           relation = &table_[hash];
71           AppendSentence(relation->substring, sentence_id);
72           if (start == begin) AppendSentence(relation->left, sentence_id);
73         }
74         AppendSentence(relation->right, sentence_id);
75         if (start == begin) AppendSentence(relation->phrase, sentence_id);
76       }
77     }
78 
79   private:
AppendSentence(std::vector<unsigned int> & vec,unsigned int sentence_id)80     void AppendSentence(std::vector<unsigned int> &vec, unsigned int sentence_id) {
81       if (vec.empty() || vec.back() != sentence_id) vec.push_back(sentence_id);
82     }
83 
84     Table table_;
85 };
86 
// Read a file with one sentence per line containing tab-delimited phrases of
// space-separated words.
// NOTE(review): presumably each phrase is added to out via
// Substrings::AddPhrase, with the line index as the (non-decreasing) sentence
// id.  The return value is presumably the number of sentences read — confirm
// against the definition.
unsigned int ReadMultiple(std::istream &in, Substrings &out);
90 
91 namespace detail {
92 extern const StringPiece kEndSentence;
93 
MakeHashes(Iterator i,const Iterator & end,std::vector<Hash> & hashes)94 template <class Iterator> void MakeHashes(Iterator i, const Iterator &end, std::vector<Hash> &hashes) {
95   hashes.clear();
96   if (i == end) return;
97   // TODO: check strict phrase boundaries after <s> and before </s>.  For now, just skip tags.
98   if ((i->data()[0] == '<') && (i->data()[i->size() - 1] == '>')) {
99     ++i;
100   }
101   for (; i != end && (*i != kEndSentence); ++i) {
102     hashes.push_back(util::MurmurHashNative(i->data(), i->size()));
103   }
104 }
105 
106 class Vertex;
107 class Arc;
108 
109 class ConditionCommon {
110   protected:
111     ConditionCommon(const Substrings &substrings);
112     ConditionCommon(const ConditionCommon &from);
113 
114     ~ConditionCommon();
115 
116     detail::Vertex &MakeGraph();
117 
118     // Temporaries in PassNGram and Evaluate to avoid reallocation.
119     std::vector<Hash> hashes_;
120 
121   private:
122     std::vector<detail::Vertex> vertices_;
123     std::vector<detail::Arc> arcs_;
124 
125     const Substrings &substrings_;
126 };
127 
128 } // namespace detail
129 
130 class Union : public detail::ConditionCommon {
131   public:
Union(const Substrings & substrings)132     explicit Union(const Substrings &substrings) : detail::ConditionCommon(substrings) {}
133 
PassNGram(const Iterator & begin,const Iterator & end)134     template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
135       detail::MakeHashes(begin, end, hashes_);
136       return hashes_.empty() || Evaluate();
137     }
138 
139   private:
140     bool Evaluate();
141 };
142 
143 class Multiple : public detail::ConditionCommon {
144   public:
Multiple(const Substrings & substrings)145     explicit Multiple(const Substrings &substrings) : detail::ConditionCommon(substrings) {}
146 
AddNGram(const Iterator & begin,const Iterator & end,const StringPiece & line,Output & output)147     template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
148       detail::MakeHashes(begin, end, hashes_);
149       if (hashes_.empty()) {
150         output.AddNGram(line);
151       } else {
152         Evaluate(line, output);
153       }
154     }
155 
AddNGram(const StringPiece & ngram,const StringPiece & line,Output & output)156     template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
157       AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
158     }
159 
Flush() const160     void Flush() const {}
161 
162   private:
163     template <class Output> void Evaluate(const StringPiece &line, Output &output);
164 };
165 
166 } // namespace phrase
167 } // namespace lm
168 #endif // LM_FILTER_PHRASE_H
169