#include "lm/interpolate/tune_instances.hh"

#include "util/file.hh"
#include "util/file_stream.hh"
#include "util/stream/chain.hh"
#include "util/stream/config.hh"
#include "util/stream/typed_stream.hh"
#include "util/string_piece.hh"

#define BOOST_TEST_MODULE InstanceTest
#include <boost/test/unit_test.hpp>

#include <cmath>
#include <string>
#include <vector>

#include <math.h>
16
17 namespace lm { namespace interpolate { namespace {
18
BOOST_AUTO_TEST_CASE(Toy)19 BOOST_AUTO_TEST_CASE(Toy) {
20 util::scoped_fd test_input(util::MakeTemp("temporary"));
21 util::FileStream(test_input.get()) << "c\n";
22
23 std::string dir("../common/test_data");
24 if (boost::unit_test::framework::master_test_suite().argc == 2) {
25 dir = boost::unit_test::framework::master_test_suite().argv[1];
26 }
27
28 #if BYTE_ORDER == LITTLE_ENDIAN
29 std::string endian = "little";
30 #elif BYTE_ORDER == BIG_ENDIAN
31 std::string endian = "big";
32 #else
33 #error "Unsupported byte order."
34 #endif
35 dir += "/" + endian + "endian/";
36
37 std::vector<StringPiece> model_names;
38 std::string full0 = dir + "toy0";
39 std::string full1 = dir + "toy1";
40 model_names.push_back(full0);
41 model_names.push_back(full1);
42
43 // Tiny buffer sizes.
44 InstancesConfig config;
45 config.model_read_chain_mem = 100;
46 config.extension_write_chain_mem = 100;
47 config.lazy_memory = 100;
48 config.sort.temp_prefix = "temporary";
49 config.sort.buffer_size = 100;
50 config.sort.total_memory = 1024;
51
52 util::SeekOrThrow(test_input.get(), 0);
53
54 Instances inst(test_input.release(), model_names, config);
55
56 BOOST_CHECK_EQUAL(1, inst.BOS());
57 const Matrix &ln_unigrams = inst.LNUnigrams();
58
59 // <unk>=0
60 BOOST_CHECK_CLOSE(-0.90309 * M_LN10, ln_unigrams(0, 0), 0.001);
61 BOOST_CHECK_CLOSE(-1 * M_LN10, ln_unigrams(0, 1), 0.001);
62 // <s>=1 doesn't matter as long as it doesn't cause NaNs.
63 BOOST_CHECK(!isnan(ln_unigrams(1, 0)));
64 BOOST_CHECK(!isnan(ln_unigrams(1, 1)));
65 // a = 2
66 BOOST_CHECK_CLOSE(-0.46943438 * M_LN10, ln_unigrams(2, 0), 0.001);
67 BOOST_CHECK_CLOSE(-0.6146491 * M_LN10, ln_unigrams(2, 1), 0.001);
68 // </s> = 3
69 BOOST_CHECK_CLOSE(-0.5720968 * M_LN10, ln_unigrams(3, 0), 0.001);
70 BOOST_CHECK_CLOSE(-0.6146491 * M_LN10, ln_unigrams(3, 1), 0.001);
71 // c = 4
72 BOOST_CHECK_CLOSE(-0.90309 * M_LN10, ln_unigrams(4, 0), 0.001); // <unk>
73 BOOST_CHECK_CLOSE(-0.7659168 * M_LN10, ln_unigrams(4, 1), 0.001);
74 // too lazy to do b = 5.
75
76 // Two instances:
77 // <s> predicts c
78 // <s> c predicts </s>
79 BOOST_REQUIRE_EQUAL(2, inst.NumInstances());
80 BOOST_CHECK_CLOSE(-0.30103 * M_LN10, inst.LNBackoffs(0)(0), 0.001);
81 BOOST_CHECK_CLOSE(-0.30103 * M_LN10, inst.LNBackoffs(0)(1), 0.001);
82
83
84 // Backoffs of <s> c
85 BOOST_CHECK_CLOSE(0.0, inst.LNBackoffs(1)(0), 0.001);
86 BOOST_CHECK_CLOSE((-0.30103 - 0.30103) * M_LN10, inst.LNBackoffs(1)(1), 0.001);
87
88 util::stream::Chain extensions(util::stream::ChainConfig(inst.ReadExtensionsEntrySize(), 2, 300));
89 inst.ReadExtensions(extensions);
90 util::stream::TypedStream<Extension> stream(extensions.Add());
91 extensions >> util::stream::kRecycle;
92
93 // The extensions are (in order of instance, vocab id, and model as they should be sorted):
94 // <s> a from both models 0 and 1 (so two instances)
95 // <s> c from model 1
96 // <s> b from model 0
97 // c </s> from model 1
98 // Magic probabilities come from querying the models directly.
99
100 // <s> a from model 0
101 BOOST_REQUIRE(stream);
102 BOOST_CHECK_EQUAL(0, stream->instance);
103 BOOST_CHECK_EQUAL(2 /* a */, stream->word);
104 BOOST_CHECK_EQUAL(0, stream->model);
105 BOOST_CHECK_CLOSE(-0.37712017 * M_LN10, stream->ln_prob, 0.001);
106
107 // <s> a from model 1
108 BOOST_REQUIRE(++stream);
109 BOOST_CHECK_EQUAL(0, stream->instance);
110 BOOST_CHECK_EQUAL(2 /* a */, stream->word);
111 BOOST_CHECK_EQUAL(1, stream->model);
112 BOOST_CHECK_CLOSE(-0.4301247 * M_LN10, stream->ln_prob, 0.001);
113
114 // <s> c from model 1
115 BOOST_REQUIRE(++stream);
116 BOOST_CHECK_EQUAL(0, stream->instance);
117 BOOST_CHECK_EQUAL(4 /* c */, stream->word);
118 BOOST_CHECK_EQUAL(1, stream->model);
119 BOOST_CHECK_CLOSE(-0.4740302 * M_LN10, stream->ln_prob, 0.001);
120
121 // <s> b from model 0
122 BOOST_REQUIRE(++stream);
123 BOOST_CHECK_EQUAL(0, stream->instance);
124 BOOST_CHECK_EQUAL(5 /* b */, stream->word);
125 BOOST_CHECK_EQUAL(0, stream->model);
126 BOOST_CHECK_CLOSE(-0.41574955 * M_LN10, stream->ln_prob, 0.001);
127
128 // c </s> from model 1
129 BOOST_REQUIRE(++stream);
130 BOOST_CHECK_EQUAL(1, stream->instance);
131 BOOST_CHECK_EQUAL(3 /* </s> */, stream->word);
132 BOOST_CHECK_EQUAL(1, stream->model);
133 BOOST_CHECK_CLOSE(-0.09113217 * M_LN10, stream->ln_prob, 0.001);
134
135 BOOST_CHECK(!++stream);
136 }
137
138 }}} // namespaces
139