#include "lm/model.hh"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <vector>

#define BOOST_TEST_MODULE ModelTest
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
9
// Some Boost versions implement BOOST_CHECK_CLOSE with templates that require
// all three arguments to have exactly the same type, so cast everything to
// double first. No trailing semicolon: call sites supply their own, so each
// use expands to exactly one statement.
#define SLOPPY_CHECK_CLOSE(ref, value, tol) BOOST_CHECK_CLOSE(static_cast<double>(ref), static_cast<double>(value), static_cast<double>(tol))
12
13 namespace lm {
14 namespace ngram {
15
operator <<(std::ostream & o,const State & state)16 std::ostream &operator<<(std::ostream &o, const State &state) {
17 o << "State length " << static_cast<unsigned int>(state.length) << ':';
18 for (const WordIndex *i = state.words; i < state.words + state.length; ++i) {
19 o << ' ' << *i;
20 }
21 return o;
22 }
23
24 namespace {
25
26 // Stupid bjam reverses the command line arguments randomly.
TestLocation()27 const char *TestLocation() {
28 if (boost::unit_test::framework::master_test_suite().argc < 3) {
29 return "test.arpa";
30 }
31 char **argv = boost::unit_test::framework::master_test_suite().argv;
32 return argv[strstr(argv[1], "nounk") ? 2 : 1];
33 }
TestNoUnkLocation()34 const char *TestNoUnkLocation() {
35 if (boost::unit_test::framework::master_test_suite().argc < 3) {
36 return "test_nounk.arpa";
37 }
38 char **argv = boost::unit_test::framework::master_test_suite().argv;
39 return argv[strstr(argv[1], "nounk") ? 1 : 2];
40 }
41
// Returns the state `model` reaches after reading `word` with history `in`.
// It builds the context [word, in.words...] (most recent word first) and
// passes it to model.GetState. Tests use this to cross-check the `out` state
// produced by FullScore.
//
// The context lives in a std::vector: a variable-length array here would be
// a compiler extension, not standard C++.
template <class Model> State GetState(const Model &model, const char *word, const State &in) {
  std::vector<WordIndex> context(in.length + 1);
  context[0] = model.GetVocabulary().Index(word);
  std::copy(in.words, in.words + in.length, context.begin() + 1);
  State ret;
  // context always has at least one element, so &context[0] is valid.
  model.GetState(&context[0], &context[0] + context.size(), ret);
  return ret;
}
50
// Scores `word` from `state` with FullScore, writing the new state to `out`.
// Checks the probability, the matched n-gram length, the independent_left
// flag, a bound on the state length, and that `out` equals the state obtained
// through GetState. Expects locals named `model`, `state`, `out` and `ret` in
// the calling scope. The bound 5 - 1 presumably reflects the order-5 test
// models (states hold at most order - 1 words).
//
// Wrapped in do { } while (0) so that `StartTest(...);` is a single statement
// and stays safe inside an unbraced if/else.
#define StartTest(word, ngram, score, indep_left) \
  do { \
    ret = model.FullScore( \
        state, \
        model.GetVocabulary().Index(word), \
        out); \
    SLOPPY_CHECK_CLOSE(score, ret.prob, 0.001); \
    BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
    BOOST_CHECK_GE(std::min<unsigned char>(ngram, 5 - 1), out.length); \
    BOOST_CHECK_EQUAL(indep_left, ret.independent_left); \
    BOOST_CHECK_EQUAL(out, GetState(model, word, state)); \
  } while (0)

// Same checks as StartTest, then advances `state` to `out` so the next call
// continues the same word sequence.
#define AppendTest(word, ngram, score, indep_left) \
  do { \
    StartTest(word, ngram, score, indep_left); \
    state = out; \
  } while (0)
65
Starters(const M & model)66 template <class M> void Starters(const M &model) {
67 FullScoreReturn ret;
68 Model::State state(model.BeginSentenceState());
69 Model::State out;
70
71 StartTest("looking", 2, -0.4846522, true);
72
73 // , probability plus <s> backoff
74 StartTest(",", 1, -1.383514 + -0.4149733, true);
75 // <unk> probability plus <s> backoff
76 StartTest("this_is_not_found", 1, -1.995635 + -0.4149733, true);
77 }
78
Continuation(const M & model)79 template <class M> void Continuation(const M &model) {
80 FullScoreReturn ret;
81 Model::State state(model.BeginSentenceState());
82 Model::State out;
83
84 AppendTest("looking", 2, -0.484652, true);
85 AppendTest("on", 3, -0.348837, true);
86 AppendTest("a", 4, -0.0155266, true);
87 AppendTest("little", 5, -0.00306122, true);
88 State preserve = state;
89 AppendTest("the", 1, -4.04005, true);
90 AppendTest("biarritz", 1, -1.9889, true);
91 AppendTest("not_found", 1, -2.29666, true);
92 AppendTest("more", 1, -1.20632 - 20.0, true);
93 AppendTest(".", 2, -0.51363, true);
94 AppendTest("</s>", 3, -0.0191651, true);
95 BOOST_CHECK_EQUAL(0, state.length);
96
97 state = preserve;
98 AppendTest("more", 5, -0.00181395, true);
99 BOOST_CHECK_EQUAL(4, state.length);
100 AppendTest("loin", 5, -0.0432557, true);
101 BOOST_CHECK_EQUAL(1, state.length);
102 }
103
Blanks(const M & model)104 template <class M> void Blanks(const M &model) {
105 FullScoreReturn ret;
106 State state(model.NullContextState());
107 State out;
108 AppendTest("also", 1, -1.687872, false);
109 AppendTest("would", 2, -2, true);
110 AppendTest("consider", 3, -3, true);
111 State preserve = state;
112 AppendTest("higher", 4, -4, true);
113 AppendTest("looking", 5, -5, true);
114 BOOST_CHECK_EQUAL(1, state.length);
115
116 state = preserve;
117 // also would consider not_found
118 AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103, true);
119
120 state = model.NullContextState();
121 // higher looking is a blank.
122 AppendTest("higher", 1, -1.509559, false);
123 AppendTest("looking", 2, -1.285941 - 0.30103, false);
124
125 State higher_looking = state;
126
127 BOOST_CHECK_EQUAL(1, state.length);
128 AppendTest("not_found", 1, -1.995635 - 0.4771212, true);
129
130 state = higher_looking;
131 // higher looking consider
132 AppendTest("consider", 1, -1.687872 - 0.4771212, true);
133
134 state = model.NullContextState();
135 AppendTest("would", 1, -1.687872, false);
136 BOOST_CHECK_EQUAL(1, state.length);
137 AppendTest("consider", 2, -1.687872 -0.30103, false);
138 BOOST_CHECK_EQUAL(2, state.length);
139 AppendTest("higher", 3, -1.509559 - 0.30103, false);
140 BOOST_CHECK_EQUAL(3, state.length);
141 AppendTest("looking", 4, -1.285941 - 0.30103, false);
142 }
143
Unknowns(const M & model)144 template <class M> void Unknowns(const M &model) {
145 FullScoreReturn ret;
146 State state(model.NullContextState());
147 State out;
148
149 AppendTest("not_found", 1, -1.995635, false);
150 State preserve = state;
151 AppendTest("not_found2", 2, -15.0, true);
152 AppendTest("not_found3", 2, -15.0 - 2.0, true);
153
154 state = preserve;
155 AppendTest("however", 2, -4, true);
156 AppendTest("not_found3", 3, -6, true);
157 }
158
MinimalState(const M & model)159 template <class M> void MinimalState(const M &model) {
160 FullScoreReturn ret;
161 State state(model.NullContextState());
162 State out;
163
164 AppendTest("baz", 1, -6.535897, true);
165 BOOST_CHECK_EQUAL(0, state.length);
166 state = model.NullContextState();
167 AppendTest("foo", 1, -3.141592, true);
168 BOOST_CHECK_EQUAL(1, state.length);
169 AppendTest("bar", 2, -6.0, true);
170 // Has to include the backoff weight.
171 BOOST_CHECK_EQUAL(1, state.length);
172 AppendTest("bar", 1, -2.718281 + 3.0, true);
173 BOOST_CHECK_EQUAL(1, state.length);
174
175 state = model.NullContextState();
176 AppendTest("to", 1, -1.687872, false);
177 AppendTest("look", 2, -0.2922095, true);
178 BOOST_CHECK_EQUAL(2, state.length);
179 AppendTest("a", 3, -7, true);
180 }
181
ExtendLeftTest(const M & model)182 template <class M> void ExtendLeftTest(const M &model) {
183 State right;
184 FullScoreReturn little(model.FullScore(model.NullContextState(), model.GetVocabulary().Index("little"), right));
185 const float kLittleProb = -1.285941;
186 SLOPPY_CHECK_CLOSE(kLittleProb, little.prob, 0.001);
187 unsigned char next_use;
188 float backoff_out[4];
189
190 FullScoreReturn extend_none(model.ExtendLeft(NULL, NULL, NULL, little.extend_left, 1, NULL, next_use));
191 BOOST_CHECK_EQUAL(0, next_use);
192 BOOST_CHECK_EQUAL(little.extend_left, extend_none.extend_left);
193 SLOPPY_CHECK_CLOSE(little.prob - little.rest, extend_none.prob, 0.001);
194 BOOST_CHECK_EQUAL(1, extend_none.ngram_length);
195
196 const WordIndex a = model.GetVocabulary().Index("a");
197 float backoff_in = 3.14;
198 // a little
199 FullScoreReturn extend_a(model.ExtendLeft(&a, &a + 1, &backoff_in, little.extend_left, 1, backoff_out, next_use));
200 BOOST_CHECK_EQUAL(1, next_use);
201 SLOPPY_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
202 SLOPPY_CHECK_CLOSE(-0.09132547 - little.rest, extend_a.prob, 0.001);
203 BOOST_CHECK_EQUAL(2, extend_a.ngram_length);
204 BOOST_CHECK(!extend_a.independent_left);
205
206 const WordIndex on = model.GetVocabulary().Index("on");
207 FullScoreReturn extend_on(model.ExtendLeft(&on, &on + 1, &backoff_in, extend_a.extend_left, 2, backoff_out, next_use));
208 BOOST_CHECK_EQUAL(1, next_use);
209 SLOPPY_CHECK_CLOSE(-0.4771212, backoff_out[0], 0.001);
210 SLOPPY_CHECK_CLOSE(-0.0283603 - (extend_a.rest + little.rest), extend_on.prob, 0.001);
211 BOOST_CHECK_EQUAL(3, extend_on.ngram_length);
212 BOOST_CHECK(!extend_on.independent_left);
213
214 const WordIndex both[2] = {a, on};
215 float backoff_in_arr[4];
216 FullScoreReturn extend_both(model.ExtendLeft(both, both + 2, backoff_in_arr, little.extend_left, 1, backoff_out, next_use));
217 BOOST_CHECK_EQUAL(2, next_use);
218 SLOPPY_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
219 SLOPPY_CHECK_CLOSE(-0.4771212, backoff_out[1], 0.001);
220 SLOPPY_CHECK_CLOSE(-0.0283603 - little.rest, extend_both.prob, 0.001);
221 BOOST_CHECK_EQUAL(3, extend_both.ngram_length);
222 BOOST_CHECK(!extend_both.independent_left);
223 BOOST_CHECK_EQUAL(extend_on.extend_left, extend_both.extend_left);
224 }
225
// Scores the word at reversed position num_words - word - 1 with
// FullScoreForgotState, using `provide` words of context starting at
// indices + num_words - word. Checks the probability and n-gram length, then
// repeats the computation via GetState + FullScore and requires the same
// state and results. Expects locals `model`, `indices`, `num_words`, `ret`,
// `state`, `out` and `before` in the calling scope.
//
// Wrapped in do { } while (0) so that `StatelessTest(...);` is one statement.
#define StatelessTest(word, provide, ngram, score) \
  do { \
    ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \
    SLOPPY_CHECK_CLOSE(score, ret.prob, 0.001); \
    BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
    model.GetState(indices + num_words - word, indices + num_words - word + provide, before); \
    ret = model.FullScore(before, indices[num_words - word - 1], out); \
    BOOST_CHECK(state == out); \
    SLOPPY_CHECK_CLOSE(score, ret.prob, 0.001); \
    BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
  } while (0)
235
Stateless(const M & model)236 template <class M> void Stateless(const M &model) {
237 const char *words[] = {"<s>", "looking", "on", "a", "little", "the", "biarritz", "not_found", "more", ".", "</s>"};
238 const size_t num_words = sizeof(words) / sizeof(const char*);
239 // Silience "array subscript is above array bounds" when extracting end pointer.
240 WordIndex indices[num_words + 1];
241 for (unsigned int i = 0; i < num_words; ++i) {
242 indices[num_words - 1 - i] = model.GetVocabulary().Index(words[i]);
243 }
244 FullScoreReturn ret;
245 State state, out, before;
246
247 ret = model.FullScoreForgotState(indices + num_words - 1, indices + num_words, indices[num_words - 2], state);
248 SLOPPY_CHECK_CLOSE(-0.484652, ret.prob, 0.001);
249 StatelessTest(1, 1, 2, -0.484652);
250
251 // looking
252 StatelessTest(1, 2, 2, -0.484652);
253 // on
254 AppendTest("on", 3, -0.348837, true);
255 StatelessTest(2, 3, 3, -0.348837);
256 StatelessTest(2, 2, 3, -0.348837);
257 StatelessTest(2, 1, 2, -0.4638903);
258 // a
259 StatelessTest(3, 4, 4, -0.0155266);
260 // little
261 AppendTest("little", 5, -0.00306122, true);
262 StatelessTest(4, 5, 5, -0.00306122);
263 // the
264 AppendTest("the", 1, -4.04005, true);
265 StatelessTest(5, 5, 1, -4.04005);
266 // No context of the.
267 StatelessTest(5, 0, 1, -1.687872);
268 // biarritz
269 StatelessTest(6, 1, 1, -1.9889);
270 // not found
271 StatelessTest(7, 1, 1, -2.29666);
272 StatelessTest(7, 0, 1, -1.995635);
273
274 WordIndex unk[1];
275 unk[0] = 0;
276 model.GetState(unk, unk + 1, state);
277 BOOST_CHECK_EQUAL(1, state.length);
278 BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.words[0]);
279 }
280
NoUnkCheck(const M & model)281 template <class M> void NoUnkCheck(const M &model) {
282 WordIndex unk_index = 0;
283 State state;
284
285 FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state);
286 SLOPPY_CHECK_CLOSE(-100.0, ret.prob, 0.001);
287 }
288
Everything(const M & m)289 template <class M> void Everything(const M &m) {
290 Starters(m);
291 Continuation(m);
292 Blanks(m);
293 Unknowns(m);
294 MinimalState(m);
295 ExtendLeftTest(m);
296 Stateless(m);
297 }
298
299 class ExpectEnumerateVocab : public EnumerateVocab {
300 public:
ExpectEnumerateVocab()301 ExpectEnumerateVocab() {}
302
Add(WordIndex index,const StringPiece & str)303 void Add(WordIndex index, const StringPiece &str) {
304 BOOST_CHECK_EQUAL(seen.size(), index);
305 seen.push_back(std::string(str.data(), str.length()));
306 }
307
Check(const base::Vocabulary & vocab)308 void Check(const base::Vocabulary &vocab) {
309 BOOST_CHECK_EQUAL(37ULL, seen.size());
310 BOOST_REQUIRE(!seen.empty());
311 BOOST_CHECK_EQUAL("<unk>", seen[0]);
312 for (WordIndex i = 0; i < seen.size(); ++i) {
313 BOOST_CHECK_EQUAL(i, vocab.Index(seen[i]));
314 }
315 }
316
Clear()317 void Clear() {
318 seen.clear();
319 }
320
321 std::vector<std::string> seen;
322 };
323
LoadingTest()324 template <class ModelT> void LoadingTest() {
325 Config config;
326 config.arpa_complain = Config::NONE;
327 config.messages = NULL;
328 config.probing_multiplier = 2.0;
329 {
330 ExpectEnumerateVocab enumerate;
331 config.enumerate_vocab = &enumerate;
332 ModelT m(TestLocation(), config);
333 enumerate.Check(m.GetVocabulary());
334 BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound());
335 Everything(m);
336 }
337 {
338 ExpectEnumerateVocab enumerate;
339 config.enumerate_vocab = &enumerate;
340 ModelT m(TestNoUnkLocation(), config);
341 enumerate.Check(m.GetVocabulary());
342 BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound());
343 NoUnkCheck(m);
344 }
345 }
346
// Load each model type directly from ARPA and run the shared checks. The set
// of model types matches the write_and_read_* binary round-trip cases.
BOOST_AUTO_TEST_CASE(probing) {
  LoadingTest<ProbingModel>();
}
BOOST_AUTO_TEST_CASE(rest_probing) {
  LoadingTest<RestProbingModel>();
}
BOOST_AUTO_TEST_CASE(trie) {
  LoadingTest<TrieModel>();
}
BOOST_AUTO_TEST_CASE(quant_trie) {
  LoadingTest<QuantTrieModel>();
}
BOOST_AUTO_TEST_CASE(bhiksha_trie) {
  LoadingTest<ArrayTrieModel>();
}
BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) {
  LoadingTest<QuantArrayTrieModel>();
}
362
BinaryTest(Config::WriteMethod write_method)363 template <class ModelT> void BinaryTest(Config::WriteMethod write_method) {
364 Config config;
365 config.write_mmap = "test.binary";
366 config.messages = NULL;
367 config.write_method = write_method;
368 ExpectEnumerateVocab enumerate;
369 config.enumerate_vocab = &enumerate;
370
371 {
372 ModelT copy_model(TestLocation(), config);
373 enumerate.Check(copy_model.GetVocabulary());
374 enumerate.Clear();
375 Everything(copy_model);
376 }
377
378 config.write_mmap = NULL;
379
380 ModelType type;
381 BOOST_REQUIRE(RecognizeBinary("test.binary", type));
382 BOOST_CHECK_EQUAL(ModelT::kModelType, type);
383
384 {
385 ModelT binary("test.binary", config);
386 enumerate.Check(binary.GetVocabulary());
387 Everything(binary);
388 }
389 unlink("test.binary");
390
391 // Now test without <unk>.
392 config.write_mmap = "test_nounk.binary";
393 config.messages = NULL;
394 enumerate.Clear();
395 {
396 ModelT copy_model(TestNoUnkLocation(), config);
397 enumerate.Check(copy_model.GetVocabulary());
398 enumerate.Clear();
399 NoUnkCheck(copy_model);
400 }
401 config.write_mmap = NULL;
402 {
403 ModelT binary(TestNoUnkLocation(), config);
404 enumerate.Check(binary.GetVocabulary());
405 NoUnkCheck(binary);
406 }
407 unlink("test_nounk.binary");
408 }
409
BinaryTest()410 template <class ModelT> void BinaryTest() {
411 BinaryTest<ModelT>(Config::WRITE_MMAP);
412 BinaryTest<ModelT>(Config::WRITE_AFTER);
413 }
414
// Binary write/read round-trips, one test case per model type.
BOOST_AUTO_TEST_CASE(write_and_read_probing) { BinaryTest<ProbingModel>(); }
BOOST_AUTO_TEST_CASE(write_and_read_rest_probing) { BinaryTest<RestProbingModel>(); }
BOOST_AUTO_TEST_CASE(write_and_read_trie) { BinaryTest<TrieModel>(); }
BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) { BinaryTest<QuantTrieModel>(); }
BOOST_AUTO_TEST_CASE(write_and_read_array_trie) { BinaryTest<ArrayTrieModel>(); }
BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) { BinaryTest<QuantArrayTrieModel>(); }
433
// RestProbingModel: checks the rest cost reported for "." from the null
// context, and for </s> immediately after it.
BOOST_AUTO_TEST_CASE(rest_max) {
  Config config;
  config.arpa_complain = Config::NONE;
  config.messages = NULL;

  RestProbingModel model(TestLocation(), config);
  State state;
  State out;
  const WordIndex period = model.GetVocabulary().Index(".");
  const FullScoreReturn ret = model.FullScore(model.NullContextState(), period, state);
  SLOPPY_CHECK_CLOSE(-0.2705918, ret.rest, 0.001);
  SLOPPY_CHECK_CLOSE(-0.01916512, model.FullScore(state, model.GetVocabulary().EndSentence(), out).rest, 0.001);
}
445
446 } // namespace
447 } // namespace ngram
448 } // namespace lm
449