#include "lm/interpolate/tune_instances.hh"

#include "util/file.hh"
#include "util/file_stream.hh"
#include "util/stream/chain.hh"
#include "util/stream/config.hh"
#include "util/stream/typed_stream.hh"
#include "util/string_piece.hh"

#define BOOST_TEST_MODULE InstanceTest
#include <boost/test/unit_test.hpp>

#include <cmath>
#include <string>
#include <vector>

#include <math.h>
16 
17 namespace lm { namespace interpolate { namespace {
18 
BOOST_AUTO_TEST_CASE(Toy)19 BOOST_AUTO_TEST_CASE(Toy) {
20   util::scoped_fd test_input(util::MakeTemp("temporary"));
21   util::FileStream(test_input.get()) << "c\n";
22 
23   std::string dir("../common/test_data");
24   if (boost::unit_test::framework::master_test_suite().argc == 2) {
25     dir = boost::unit_test::framework::master_test_suite().argv[1];
26   }
27 
28 #if BYTE_ORDER == LITTLE_ENDIAN
29   std::string endian = "little";
30 #elif BYTE_ORDER == BIG_ENDIAN
31   std::string endian = "big";
32 #else
33 #error "Unsupported byte order."
34 #endif
35   dir += "/" + endian + "endian/";
36 
37   std::vector<StringPiece> model_names;
38   std::string full0 = dir + "toy0";
39   std::string full1 = dir + "toy1";
40   model_names.push_back(full0);
41   model_names.push_back(full1);
42 
43   // Tiny buffer sizes.
44   InstancesConfig config;
45   config.model_read_chain_mem = 100;
46   config.extension_write_chain_mem = 100;
47   config.lazy_memory = 100;
48   config.sort.temp_prefix = "temporary";
49   config.sort.buffer_size = 100;
50   config.sort.total_memory = 1024;
51 
52   util::SeekOrThrow(test_input.get(), 0);
53 
54   Instances inst(test_input.release(), model_names, config);
55 
56   BOOST_CHECK_EQUAL(1, inst.BOS());
57   const Matrix &ln_unigrams = inst.LNUnigrams();
58 
59   // <unk>=0
60   BOOST_CHECK_CLOSE(-0.90309 * M_LN10, ln_unigrams(0, 0), 0.001);
61   BOOST_CHECK_CLOSE(-1 * M_LN10, ln_unigrams(0, 1), 0.001);
62   // <s>=1 doesn't matter as long as it doesn't cause NaNs.
63   BOOST_CHECK(!isnan(ln_unigrams(1, 0)));
64   BOOST_CHECK(!isnan(ln_unigrams(1, 1)));
65   // a = 2
66   BOOST_CHECK_CLOSE(-0.46943438 * M_LN10, ln_unigrams(2, 0), 0.001);
67   BOOST_CHECK_CLOSE(-0.6146491 * M_LN10, ln_unigrams(2, 1), 0.001);
68   // </s> = 3
69   BOOST_CHECK_CLOSE(-0.5720968 * M_LN10, ln_unigrams(3, 0), 0.001);
70   BOOST_CHECK_CLOSE(-0.6146491 * M_LN10, ln_unigrams(3, 1), 0.001);
71   // c = 4
72   BOOST_CHECK_CLOSE(-0.90309 * M_LN10, ln_unigrams(4, 0), 0.001); // <unk>
73   BOOST_CHECK_CLOSE(-0.7659168 * M_LN10, ln_unigrams(4, 1), 0.001);
74   // too lazy to do b = 5.
75 
76   // Two instances:
77   // <s> predicts c
78   // <s> c predicts </s>
79   BOOST_REQUIRE_EQUAL(2, inst.NumInstances());
80   BOOST_CHECK_CLOSE(-0.30103 * M_LN10, inst.LNBackoffs(0)(0), 0.001);
81   BOOST_CHECK_CLOSE(-0.30103 * M_LN10, inst.LNBackoffs(0)(1), 0.001);
82 
83 
84   // Backoffs of <s> c
85   BOOST_CHECK_CLOSE(0.0, inst.LNBackoffs(1)(0), 0.001);
86   BOOST_CHECK_CLOSE((-0.30103 - 0.30103) * M_LN10, inst.LNBackoffs(1)(1), 0.001);
87 
88   util::stream::Chain extensions(util::stream::ChainConfig(inst.ReadExtensionsEntrySize(), 2, 300));
89   inst.ReadExtensions(extensions);
90   util::stream::TypedStream<Extension> stream(extensions.Add());
91   extensions >> util::stream::kRecycle;
92 
93   // The extensions are (in order of instance, vocab id, and model as they should be sorted):
94   // <s> a from both models 0 and 1 (so two instances)
95   // <s> c from model 1
96   // <s> b from model 0
97   // c </s> from model 1
98   // Magic probabilities come from querying the models directly.
99 
100   // <s> a from model 0
101   BOOST_REQUIRE(stream);
102   BOOST_CHECK_EQUAL(0, stream->instance);
103   BOOST_CHECK_EQUAL(2 /* a */, stream->word);
104   BOOST_CHECK_EQUAL(0, stream->model);
105   BOOST_CHECK_CLOSE(-0.37712017 * M_LN10, stream->ln_prob, 0.001);
106 
107   // <s> a from model 1
108   BOOST_REQUIRE(++stream);
109   BOOST_CHECK_EQUAL(0, stream->instance);
110   BOOST_CHECK_EQUAL(2 /* a */, stream->word);
111   BOOST_CHECK_EQUAL(1, stream->model);
112   BOOST_CHECK_CLOSE(-0.4301247 * M_LN10, stream->ln_prob, 0.001);
113 
114   // <s> c from model 1
115   BOOST_REQUIRE(++stream);
116   BOOST_CHECK_EQUAL(0, stream->instance);
117   BOOST_CHECK_EQUAL(4 /* c */, stream->word);
118   BOOST_CHECK_EQUAL(1, stream->model);
119   BOOST_CHECK_CLOSE(-0.4740302 * M_LN10, stream->ln_prob, 0.001);
120 
121   // <s> b from model 0
122   BOOST_REQUIRE(++stream);
123   BOOST_CHECK_EQUAL(0, stream->instance);
124   BOOST_CHECK_EQUAL(5 /* b */, stream->word);
125   BOOST_CHECK_EQUAL(0, stream->model);
126   BOOST_CHECK_CLOSE(-0.41574955 * M_LN10, stream->ln_prob, 0.001);
127 
128   // c </s> from model 1
129   BOOST_REQUIRE(++stream);
130   BOOST_CHECK_EQUAL(1, stream->instance);
131   BOOST_CHECK_EQUAL(3 /* </s> */, stream->word);
132   BOOST_CHECK_EQUAL(1, stream->model);
133   BOOST_CHECK_CLOSE(-0.09113217 * M_LN10, stream->ln_prob, 0.001);
134 
135   BOOST_CHECK(!++stream);
136 }
137 
138 }}} // namespaces
139