#include "lm/model.hh"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>

#define BOOST_TEST_MODULE ModelTest
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
9 
// Apparently some Boost versions use templates and are pretty strict about types matching,
// so every argument is cast to double before reaching BOOST_CHECK_CLOSE.
// No trailing semicolon: call sites supply their own, so each use expands to exactly one
// statement (safe inside an unbraced if/else).
#define SLOPPY_CHECK_CLOSE(ref, value, tol) BOOST_CHECK_CLOSE(static_cast<double>(ref), static_cast<double>(value), static_cast<double>(tol))
12 
13 namespace lm {
14 namespace ngram {
15 
operator <<(std::ostream & o,const State & state)16 std::ostream &operator<<(std::ostream &o, const State &state) {
17   o << "State length " << static_cast<unsigned int>(state.length) << ':';
18   for (const WordIndex *i = state.words; i < state.words + state.length; ++i) {
19     o << ' ' << *i;
20   }
21   return o;
22 }
23 
24 namespace {
25 
26 // Stupid bjam reverses the command line arguments randomly.
TestLocation()27 const char *TestLocation() {
28   if (boost::unit_test::framework::master_test_suite().argc < 3) {
29     return "test.arpa";
30   }
31   char **argv = boost::unit_test::framework::master_test_suite().argv;
32   return argv[strstr(argv[1], "nounk") ? 2 : 1];
33 }
TestNoUnkLocation()34 const char *TestNoUnkLocation() {
35   if (boost::unit_test::framework::master_test_suite().argc < 3) {
36     return "test_nounk.arpa";
37   }
38   char **argv = boost::unit_test::framework::master_test_suite().argv;
39   return argv[strstr(argv[1], "nounk") ? 1 : 2];
40 }
41 
// Returns the state Model::GetState produces for `word` followed by the words
// stored in `in` (context is passed newest word first: `word`, then in.words).
// Used to cross-check the `out` state returned by FullScore.
template <class Model> State GetState(const Model &model, const char *word, const State &in) {
  // std::vector rather than `WordIndex context[in.length + 1]`: a runtime-sized
  // array is a compiler extension, not standard C++.
  std::vector<WordIndex> context(in.length + 1);
  context[0] = model.GetVocabulary().Index(word);
  std::copy(in.words, in.words + in.length, context.begin() + 1);
  State ret;
  // context always holds at least one element, so &context[0] is valid.
  model.GetState(&context[0], &context[0] + context.size(), ret);
  return ret;
}
50 
// StartTest scores `word` from `state` with FullScore (result in `ret`, new
// state in `out`) and checks that:
//  - ret.prob is close to `score` (tolerance 0.001 via SLOPPY_CHECK_CLOSE),
//  - ret.ngram_length equals `ngram`,
//  - out.length does not exceed min(ngram, 5 - 1) (5 presumably being the
//    order of the test model — TODO confirm against test.arpa),
//  - ret.independent_left equals `indep_left`,
//  - `out` equals the state GetState computes for `word` after `state`.
// The caller must have locals named `model`, `state`, `out` and `ret` in scope.
// `state` is not modified.
#define StartTest(word, ngram, score, indep_left) \
  ret = model.FullScore( \
      state, \
      model.GetVocabulary().Index(word), \
      out);\
  SLOPPY_CHECK_CLOSE(score, ret.prob, 0.001); \
  BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
  BOOST_CHECK_GE(std::min<unsigned char>(ngram, 5 - 1), out.length); \
  BOOST_CHECK_EQUAL(indep_left, ret.independent_left); \
  BOOST_CHECK_EQUAL(out, GetState(model, word, state));

// AppendTest performs the same checks as StartTest, then advances `state` to
// `out` so that consecutive calls score a word sequence.
#define AppendTest(word, ngram, score, indep_left) \
  StartTest(word, ngram, score, indep_left) \
  state = out;
65 
Starters(const M & model)66 template <class M> void Starters(const M &model) {
67   FullScoreReturn ret;
68   Model::State state(model.BeginSentenceState());
69   Model::State out;
70 
71   StartTest("looking", 2, -0.4846522, true);
72 
73   // , probability plus <s> backoff
74   StartTest(",", 1, -1.383514 + -0.4149733, true);
75   // <unk> probability plus <s> backoff
76   StartTest("this_is_not_found", 1, -1.995635 + -0.4149733, true);
77 }
78 
Continuation(const M & model)79 template <class M> void Continuation(const M &model) {
80   FullScoreReturn ret;
81   Model::State state(model.BeginSentenceState());
82   Model::State out;
83 
84   AppendTest("looking", 2, -0.484652, true);
85   AppendTest("on", 3, -0.348837, true);
86   AppendTest("a", 4, -0.0155266, true);
87   AppendTest("little", 5, -0.00306122, true);
88   State preserve = state;
89   AppendTest("the", 1, -4.04005, true);
90   AppendTest("biarritz", 1, -1.9889, true);
91   AppendTest("not_found", 1, -2.29666, true);
92   AppendTest("more", 1, -1.20632 - 20.0, true);
93   AppendTest(".", 2, -0.51363, true);
94   AppendTest("</s>", 3, -0.0191651, true);
95   BOOST_CHECK_EQUAL(0, state.length);
96 
97   state = preserve;
98   AppendTest("more", 5, -0.00181395, true);
99   BOOST_CHECK_EQUAL(4, state.length);
100   AppendTest("loin", 5, -0.0432557, true);
101   BOOST_CHECK_EQUAL(1, state.length);
102 }
103 
// Scores word sequences from the null-context state through contexts the
// original comments call "blanks" (e.g. "higher looking is a blank").
// Checks probability, matched n-gram length, independent_left and the
// resulting state after each word.
// NOTE(review): the expected constants are hand-derived from the test ARPA
// file. 0.30103 == log10(2) and 0.4771212 == log10(3) appear to be backoff
// charges added to base probabilities — confirm against test.arpa.
template <class M> void Blanks(const M &model) {
  FullScoreReturn ret;
  State state(model.NullContextState());
  State out;
  AppendTest("also", 1, -1.687872, false);
  AppendTest("would", 2, -2, true);
  AppendTest("consider", 3, -3, true);
  // Branch point: the state after "also would consider".
  State preserve = state;
  AppendTest("higher", 4, -4, true);
  AppendTest("looking", 5, -5, true);
  BOOST_CHECK_EQUAL(1, state.length);

  state = preserve;
  // also would consider not_found
  AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103, true);

  state = model.NullContextState();
  // higher looking is a blank.
  AppendTest("higher", 1, -1.509559, false);
  AppendTest("looking", 2, -1.285941 - 0.30103, false);

  // Branch point: the state after "higher looking".
  State higher_looking = state;

  BOOST_CHECK_EQUAL(1, state.length);
  AppendTest("not_found", 1, -1.995635 - 0.4771212, true);

  state = higher_looking;
  // higher looking consider
  AppendTest("consider", 1, -1.687872 - 0.4771212, true);

  // Build "would consider higher looking" from scratch, checking that the
  // state length grows by one word per step.
  state = model.NullContextState();
  AppendTest("would", 1, -1.687872, false);
  BOOST_CHECK_EQUAL(1, state.length);
  AppendTest("consider", 2, -1.687872 -0.30103, false);
  BOOST_CHECK_EQUAL(2, state.length);
  AppendTest("higher", 3, -1.509559 - 0.30103, false);
  BOOST_CHECK_EQUAL(3, state.length);
  AppendTest("looking", 4, -1.285941 - 0.30103, false);
}
143 
// Scores out-of-vocabulary words ("not_found*", all mapped to the same
// unknown index) in different contexts, starting from the null-context state.
// Covers two checks: an unknown word following another unknown, and an
// unknown following the known word "however".
template <class M> void Unknowns(const M &model) {
  FullScoreReturn ret;
  State state(model.NullContextState());
  State out;

  AppendTest("not_found", 1, -1.995635, false);
  // Branch point: the state after a single unknown word.
  State preserve = state;
  AppendTest("not_found2", 2, -15.0, true);
  // NOTE(review): the extra -2.0 presumably comes from a backoff charge in
  // the test ARPA file — confirm.
  AppendTest("not_found3", 2, -15.0 - 2.0, true);

  state = preserve;
  AppendTest("however", 2, -4, true);
  AppendTest("not_found3", 3, -6, true);
}
158 
// Checks how many context words the returned state keeps (state.length)
// after various words, along with their probabilities. For example, "baz"
// leaves an empty state while "foo" keeps one word.
template <class M> void MinimalState(const M &model) {
  FullScoreReturn ret;
  State state(model.NullContextState());
  State out;

  AppendTest("baz", 1, -6.535897, true);
  BOOST_CHECK_EQUAL(0, state.length);
  state = model.NullContextState();
  AppendTest("foo", 1, -3.141592, true);
  BOOST_CHECK_EQUAL(1, state.length);
  AppendTest("bar", 2, -6.0, true);
  // Has to include the backoff weight.
  BOOST_CHECK_EQUAL(1, state.length);
  // NOTE(review): +3.0 presumably is the backoff carried by the one-word
  // state above — confirm against test.arpa.
  AppendTest("bar", 1, -2.718281 + 3.0, true);
  BOOST_CHECK_EQUAL(1, state.length);

  state = model.NullContextState();
  AppendTest("to", 1, -1.687872, false);
  AppendTest("look", 2, -0.2922095, true);
  BOOST_CHECK_EQUAL(2, state.length);
  AppendTest("a", 3, -7, true);
}
181 
ExtendLeftTest(const M & model)182 template <class M> void ExtendLeftTest(const M &model) {
183   State right;
184   FullScoreReturn little(model.FullScore(model.NullContextState(), model.GetVocabulary().Index("little"), right));
185   const float kLittleProb = -1.285941;
186   SLOPPY_CHECK_CLOSE(kLittleProb, little.prob, 0.001);
187   unsigned char next_use;
188   float backoff_out[4];
189 
190   FullScoreReturn extend_none(model.ExtendLeft(NULL, NULL, NULL, little.extend_left, 1, NULL, next_use));
191   BOOST_CHECK_EQUAL(0, next_use);
192   BOOST_CHECK_EQUAL(little.extend_left, extend_none.extend_left);
193   SLOPPY_CHECK_CLOSE(little.prob - little.rest, extend_none.prob, 0.001);
194   BOOST_CHECK_EQUAL(1, extend_none.ngram_length);
195 
196   const WordIndex a = model.GetVocabulary().Index("a");
197   float backoff_in = 3.14;
198   // a little
199   FullScoreReturn extend_a(model.ExtendLeft(&a, &a + 1, &backoff_in, little.extend_left, 1, backoff_out, next_use));
200   BOOST_CHECK_EQUAL(1, next_use);
201   SLOPPY_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
202   SLOPPY_CHECK_CLOSE(-0.09132547 - little.rest, extend_a.prob, 0.001);
203   BOOST_CHECK_EQUAL(2, extend_a.ngram_length);
204   BOOST_CHECK(!extend_a.independent_left);
205 
206   const WordIndex on = model.GetVocabulary().Index("on");
207   FullScoreReturn extend_on(model.ExtendLeft(&on, &on + 1, &backoff_in, extend_a.extend_left, 2, backoff_out, next_use));
208   BOOST_CHECK_EQUAL(1, next_use);
209   SLOPPY_CHECK_CLOSE(-0.4771212, backoff_out[0], 0.001);
210   SLOPPY_CHECK_CLOSE(-0.0283603 - (extend_a.rest + little.rest), extend_on.prob, 0.001);
211   BOOST_CHECK_EQUAL(3, extend_on.ngram_length);
212   BOOST_CHECK(!extend_on.independent_left);
213 
214   const WordIndex both[2] = {a, on};
215   float backoff_in_arr[4];
216   FullScoreReturn extend_both(model.ExtendLeft(both, both + 2, backoff_in_arr, little.extend_left, 1, backoff_out, next_use));
217   BOOST_CHECK_EQUAL(2, next_use);
218   SLOPPY_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
219   SLOPPY_CHECK_CLOSE(-0.4771212, backoff_out[1], 0.001);
220   SLOPPY_CHECK_CLOSE(-0.0283603 - little.rest, extend_both.prob, 0.001);
221   BOOST_CHECK_EQUAL(3, extend_both.ngram_length);
222   BOOST_CHECK(!extend_both.independent_left);
223   BOOST_CHECK_EQUAL(extend_on.extend_left, extend_both.extend_left);
224 }
225 
// StatelessTest scores words[word] (the entry of the `words` array in the
// caller) given `provide` context words taken from the reversed `indices`
// array, i.e. words[word - 1], words[word - 2], ... newest first.
// It checks both FullScoreForgotState and GetState + FullScore:
//  - each result's prob is close to `score` and its ngram_length equals `ngram`;
//  - both paths produce the same state.
// The caller must have `model`, `indices`, `num_words`, `ret`, `state`,
// `before` and `out` in scope. Afterwards `state` holds the
// FullScoreForgotState result.
#define StatelessTest(word, provide, ngram, score) \
  ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \
  SLOPPY_CHECK_CLOSE(score, ret.prob, 0.001); \
  BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
  model.GetState(indices + num_words - word, indices + num_words - word + provide, before); \
  ret = model.FullScore(before, indices[num_words - word - 1], out); \
  BOOST_CHECK(state == out); \
  SLOPPY_CHECK_CLOSE(score, ret.prob, 0.001); \
  BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length);
235 
Stateless(const M & model)236 template <class M> void Stateless(const M &model) {
237   const char *words[] = {"<s>", "looking", "on", "a", "little", "the", "biarritz", "not_found", "more", ".", "</s>"};
238   const size_t num_words = sizeof(words) / sizeof(const char*);
239   // Silience "array subscript is above array bounds" when extracting end pointer.
240   WordIndex indices[num_words + 1];
241   for (unsigned int i = 0; i < num_words; ++i) {
242     indices[num_words - 1 - i] = model.GetVocabulary().Index(words[i]);
243   }
244   FullScoreReturn ret;
245   State state, out, before;
246 
247   ret = model.FullScoreForgotState(indices + num_words - 1, indices + num_words, indices[num_words - 2], state);
248   SLOPPY_CHECK_CLOSE(-0.484652, ret.prob, 0.001);
249   StatelessTest(1, 1, 2, -0.484652);
250 
251   // looking
252   StatelessTest(1, 2, 2, -0.484652);
253   // on
254   AppendTest("on", 3, -0.348837, true);
255   StatelessTest(2, 3, 3, -0.348837);
256   StatelessTest(2, 2, 3, -0.348837);
257   StatelessTest(2, 1, 2, -0.4638903);
258   // a
259   StatelessTest(3, 4, 4, -0.0155266);
260   // little
261   AppendTest("little", 5, -0.00306122, true);
262   StatelessTest(4, 5, 5, -0.00306122);
263   // the
264   AppendTest("the", 1, -4.04005, true);
265   StatelessTest(5, 5, 1, -4.04005);
266   // No context of the.
267   StatelessTest(5, 0, 1, -1.687872);
268   // biarritz
269   StatelessTest(6, 1, 1, -1.9889);
270   // not found
271   StatelessTest(7, 1, 1, -2.29666);
272   StatelessTest(7, 0, 1, -1.995635);
273 
274   WordIndex unk[1];
275   unk[0] = 0;
276   model.GetState(unk, unk + 1, state);
277   BOOST_CHECK_EQUAL(1, state.length);
278   BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.words[0]);
279 }
280 
NoUnkCheck(const M & model)281 template <class M> void NoUnkCheck(const M &model) {
282   WordIndex unk_index = 0;
283   State state;
284 
285   FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state);
286   SLOPPY_CHECK_CLOSE(-100.0, ret.prob, 0.001);
287 }
288 
Everything(const M & m)289 template <class M> void Everything(const M &m) {
290   Starters(m);
291   Continuation(m);
292   Blanks(m);
293   Unknowns(m);
294   MinimalState(m);
295   ExtendLeftTest(m);
296   Stateless(m);
297 }
298 
299 class ExpectEnumerateVocab : public EnumerateVocab {
300   public:
ExpectEnumerateVocab()301     ExpectEnumerateVocab() {}
302 
Add(WordIndex index,const StringPiece & str)303     void Add(WordIndex index, const StringPiece &str) {
304       BOOST_CHECK_EQUAL(seen.size(), index);
305       seen.push_back(std::string(str.data(), str.length()));
306     }
307 
Check(const base::Vocabulary & vocab)308     void Check(const base::Vocabulary &vocab) {
309       BOOST_CHECK_EQUAL(37ULL, seen.size());
310       BOOST_REQUIRE(!seen.empty());
311       BOOST_CHECK_EQUAL("<unk>", seen[0]);
312       for (WordIndex i = 0; i < seen.size(); ++i) {
313         BOOST_CHECK_EQUAL(i, vocab.Index(seen[i]));
314       }
315     }
316 
Clear()317     void Clear() {
318       seen.clear();
319     }
320 
321     std::vector<std::string> seen;
322 };
323 
LoadingTest()324 template <class ModelT> void LoadingTest() {
325   Config config;
326   config.arpa_complain = Config::NONE;
327   config.messages = NULL;
328   config.probing_multiplier = 2.0;
329   {
330     ExpectEnumerateVocab enumerate;
331     config.enumerate_vocab = &enumerate;
332     ModelT m(TestLocation(), config);
333     enumerate.Check(m.GetVocabulary());
334     BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound());
335     Everything(m);
336   }
337   {
338     ExpectEnumerateVocab enumerate;
339     config.enumerate_vocab = &enumerate;
340     ModelT m(TestNoUnkLocation(), config);
341     enumerate.Check(m.GetVocabulary());
342     BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound());
343     NoUnkCheck(m);
344   }
345 }
346 
// Load each model type straight from the ARPA test files and run the checks.
BOOST_AUTO_TEST_CASE(probing) {
  LoadingTest<Model>();
}
BOOST_AUTO_TEST_CASE(trie) {
  LoadingTest<TrieModel>();
}
BOOST_AUTO_TEST_CASE(quant_trie) {
  LoadingTest<QuantTrieModel>();
}
// NOTE(review): "bhiksha" in the case names refers to the ArrayTrieModel variants.
BOOST_AUTO_TEST_CASE(bhiksha_trie) {
  LoadingTest<ArrayTrieModel>();
}
BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) {
  LoadingTest<QuantArrayTrieModel>();
}
362 
BinaryTest(Config::WriteMethod write_method)363 template <class ModelT> void BinaryTest(Config::WriteMethod write_method) {
364   Config config;
365   config.write_mmap = "test.binary";
366   config.messages = NULL;
367   config.write_method = write_method;
368   ExpectEnumerateVocab enumerate;
369   config.enumerate_vocab = &enumerate;
370 
371   {
372     ModelT copy_model(TestLocation(), config);
373     enumerate.Check(copy_model.GetVocabulary());
374     enumerate.Clear();
375     Everything(copy_model);
376   }
377 
378   config.write_mmap = NULL;
379 
380   ModelType type;
381   BOOST_REQUIRE(RecognizeBinary("test.binary", type));
382   BOOST_CHECK_EQUAL(ModelT::kModelType, type);
383 
384   {
385     ModelT binary("test.binary", config);
386     enumerate.Check(binary.GetVocabulary());
387     Everything(binary);
388   }
389   unlink("test.binary");
390 
391   // Now test without <unk>.
392   config.write_mmap = "test_nounk.binary";
393   config.messages = NULL;
394   enumerate.Clear();
395   {
396     ModelT copy_model(TestNoUnkLocation(), config);
397     enumerate.Check(copy_model.GetVocabulary());
398     enumerate.Clear();
399     NoUnkCheck(copy_model);
400   }
401   config.write_mmap = NULL;
402   {
403     ModelT binary(TestNoUnkLocation(), config);
404     enumerate.Check(binary.GetVocabulary());
405     NoUnkCheck(binary);
406   }
407   unlink("test_nounk.binary");
408 }
409 
BinaryTest()410 template <class ModelT> void BinaryTest() {
411   BinaryTest<ModelT>(Config::WRITE_MMAP);
412   BinaryTest<ModelT>(Config::WRITE_AFTER);
413 }
414 
// For each model type: write a binary file (with each write method), read it
// back, and re-run the checks on the reloaded model.
BOOST_AUTO_TEST_CASE(write_and_read_probing) {
  BinaryTest<ProbingModel>();
}
BOOST_AUTO_TEST_CASE(write_and_read_rest_probing) {
  BinaryTest<RestProbingModel>();
}
BOOST_AUTO_TEST_CASE(write_and_read_trie) {
  BinaryTest<TrieModel>();
}
BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) {
  BinaryTest<QuantTrieModel>();
}
BOOST_AUTO_TEST_CASE(write_and_read_array_trie) {
  BinaryTest<ArrayTrieModel>();
}
BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) {
  BinaryTest<QuantArrayTrieModel>();
}
433 
// Checks the `rest` values RestProbingModel reports for "." from the null
// context, then for </s> following ".".
BOOST_AUTO_TEST_CASE(rest_max) {
  Config config;
  config.arpa_complain = Config::NONE;
  config.messages = NULL;

  RestProbingModel model(TestLocation(), config);
  State after_period, after_end;

  const WordIndex period = model.GetVocabulary().Index(".");
  const FullScoreReturn period_score(model.FullScore(model.NullContextState(), period, after_period));
  SLOPPY_CHECK_CLOSE(-0.2705918, period_score.rest, 0.001);

  const WordIndex end_sentence = model.GetVocabulary().EndSentence();
  const FullScoreReturn end_score(model.FullScore(after_period, end_sentence, after_end));
  SLOPPY_CHECK_CLOSE(-0.01916512, end_score.rest, 0.001);
}
445 
446 } // namespace
447 } // namespace ngram
448 } // namespace lm
449