1 #ifndef LM_FILTER_PHRASE_H
2 #define LM_FILTER_PHRASE_H
3
#include "util/murmur_hash.hh"
#include "util/string_piece.hh"
#include "util/tokenize_piece.hh"

#include <boost/unordered_map.hpp>

#include <cstddef>
#include <iosfwd>
#include <vector>
12
// Generates a const lookup method named Find<caps> inside the enclosing class.
// The generated method searches table_ for key.  On a hit it points `out` at
// the `lower` member of the mapped value and returns true; on a miss it
// returns false and does not modify `out`.
// The enclosing class must provide the names Hash, Table and table_.
#define LM_FILTER_PHRASE_METHOD(caps, lower) \
bool Find##caps(Hash key, const std::vector<unsigned int> *&out) const {\
  const Table::const_iterator found = table_.find(key);\
  if (found != table_.end()) {\
    out = &found->second.lower;\
    return true;\
  }\
  return false;\
}
20
21 namespace lm {
22 namespace phrase {
23
24 typedef uint64_t Hash;
25
26 class Substrings {
27 private:
28 /* This is the value in a hash table where the key is a string. It indicates
29 * four sets of sentences:
30 * substring is sentences with a phrase containing the key as a substring.
31 * left is sentencess with a phrase that begins with the key (left aligned).
32 * right is sentences with a phrase that ends with the key (right aligned).
33 * phrase is sentences where the key is a phrase.
34 * Each set is encoded as a vector of sentence ids in increasing order.
35 */
36 struct SentenceRelation {
37 std::vector<unsigned int> substring, left, right, phrase;
38 };
39 /* Most of the CPU is hash table lookups, so let's not complicate it with
40 * vector equality comparisons. If a collision happens, the SentenceRelation
41 * structure will contain the union of sentence ids over the colliding strings.
42 * In that case, the filter will be slightly more permissive.
43 * The key here is the same as boost's hash of std::vector<std::string>.
44 */
45 typedef boost::unordered_map<Hash, SentenceRelation> Table;
46
47 public:
Substrings()48 Substrings() {}
49
50 /* If the string isn't a substring of any phrase, return NULL. Otherwise,
51 * return a pointer to std::vector<unsigned int> listing sentences with
52 * matching phrases. This set may be empty for Left, Right, or Phrase.
53 * Example: const std::vector<unsigned int> *FindSubstring(Hash key)
54 */
LM_FILTER_PHRASE_METHOD(Substring,substring)55 LM_FILTER_PHRASE_METHOD(Substring, substring)
56 LM_FILTER_PHRASE_METHOD(Left, left)
57 LM_FILTER_PHRASE_METHOD(Right, right)
58 LM_FILTER_PHRASE_METHOD(Phrase, phrase)
59
60 #pragma GCC diagnostic ignored "-Wuninitialized" // end != finish so there's always an initialization
61 // sentence_id must be non-decreasing. Iterators are over words in the phrase.
62 template <class Iterator> void AddPhrase(unsigned int sentence_id, const Iterator &begin, const Iterator &end) {
63 // Iterate over all substrings.
64 for (Iterator start = begin; start != end; ++start) {
65 Hash hash = 0;
66 SentenceRelation *relation;
67 for (Iterator finish = start; finish != end; ++finish) {
68 hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *finish);
69 // Now hash is of [start, finish].
70 relation = &table_[hash];
71 AppendSentence(relation->substring, sentence_id);
72 if (start == begin) AppendSentence(relation->left, sentence_id);
73 }
74 AppendSentence(relation->right, sentence_id);
75 if (start == begin) AppendSentence(relation->phrase, sentence_id);
76 }
77 }
78
79 private:
AppendSentence(std::vector<unsigned int> & vec,unsigned int sentence_id)80 void AppendSentence(std::vector<unsigned int> &vec, unsigned int sentence_id) {
81 if (vec.empty() || vec.back() != sentence_id) vec.push_back(sentence_id);
82 }
83
84 Table table_;
85 };
86
87 // Read a file with one sentence per line containing tab-delimited phrases of
88 // space-separated words.
89 unsigned int ReadMultiple(std::istream &in, Substrings &out);
90
91 namespace detail {
92 extern const StringPiece kEndSentence;
93
MakeHashes(Iterator i,const Iterator & end,std::vector<Hash> & hashes)94 template <class Iterator> void MakeHashes(Iterator i, const Iterator &end, std::vector<Hash> &hashes) {
95 hashes.clear();
96 if (i == end) return;
97 // TODO: check strict phrase boundaries after <s> and before </s>. For now, just skip tags.
98 if ((i->data()[0] == '<') && (i->data()[i->size() - 1] == '>')) {
99 ++i;
100 }
101 for (; i != end && (*i != kEndSentence); ++i) {
102 hashes.push_back(util::MurmurHashNative(i->data(), i->size()));
103 }
104 }
105
106 class Vertex;
107 class Arc;
108
109 class ConditionCommon {
110 protected:
111 ConditionCommon(const Substrings &substrings);
112 ConditionCommon(const ConditionCommon &from);
113
114 ~ConditionCommon();
115
116 detail::Vertex &MakeGraph();
117
118 // Temporaries in PassNGram and Evaluate to avoid reallocation.
119 std::vector<Hash> hashes_;
120
121 private:
122 std::vector<detail::Vertex> vertices_;
123 std::vector<detail::Arc> arcs_;
124
125 const Substrings &substrings_;
126 };
127
128 } // namespace detail
129
130 class Union : public detail::ConditionCommon {
131 public:
Union(const Substrings & substrings)132 explicit Union(const Substrings &substrings) : detail::ConditionCommon(substrings) {}
133
PassNGram(const Iterator & begin,const Iterator & end)134 template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
135 detail::MakeHashes(begin, end, hashes_);
136 return hashes_.empty() || Evaluate();
137 }
138
139 private:
140 bool Evaluate();
141 };
142
143 class Multiple : public detail::ConditionCommon {
144 public:
Multiple(const Substrings & substrings)145 explicit Multiple(const Substrings &substrings) : detail::ConditionCommon(substrings) {}
146
AddNGram(const Iterator & begin,const Iterator & end,const StringPiece & line,Output & output)147 template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
148 detail::MakeHashes(begin, end, hashes_);
149 if (hashes_.empty()) {
150 output.AddNGram(line);
151 } else {
152 Evaluate(line, output);
153 }
154 }
155
AddNGram(const StringPiece & ngram,const StringPiece & line,Output & output)156 template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
157 AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
158 }
159
Flush() const160 void Flush() const {}
161
162 private:
163 template <class Output> void Evaluate(const StringPiece &line, Output &output);
164 };
165
166 } // namespace phrase
167 } // namespace lm
168 #endif // LM_FILTER_PHRASE_H
169