1 #include "lm/interpolate/pipeline.hh"
2 
3 #include "lm/common/compare.hh"
4 #include "lm/common/print.hh"
5 #include "lm/common/renumber.hh"
6 #include "lm/vocab.hh"
7 #include "lm/interpolate/backoff_reunification.hh"
8 #include "lm/interpolate/interpolate_info.hh"
9 #include "lm/interpolate/merge_probabilities.hh"
10 #include "lm/interpolate/merge_vocab.hh"
11 #include "lm/interpolate/normalize.hh"
12 #include "lm/interpolate/universal_vocab.hh"
13 #include "util/stream/chain.hh"
14 #include "util/stream/count_records.hh"
15 #include "util/stream/io.hh"
16 #include "util/stream/multi_stream.hh"
17 #include "util/stream/sort.hh"
18 #include "util/fixed_array.hh"
19 
20 namespace lm { namespace interpolate { namespace {
21 
22 /* Put the original input files on chains and renumber them */
SetupInputs(std::size_t buffer_size,const UniversalVocab & vocab,util::FixedArray<ModelBuffer> & models,bool exclude_highest,util::FixedArray<util::stream::Chains> & chains,util::FixedArray<util::stream::ChainPositions> & positions)23 void SetupInputs(std::size_t buffer_size, const UniversalVocab &vocab, util::FixedArray<ModelBuffer> &models, bool exclude_highest, util::FixedArray<util::stream::Chains> &chains, util::FixedArray<util::stream::ChainPositions> &positions) {
24   chains.clear();
25   positions.clear();
26   // TODO: much better memory sizing heuristics e.g. not making the chain larger than it will use.
27   util::stream::ChainConfig config(0, 2, buffer_size);
28   for (std::size_t i = 0; i < models.size(); ++i) {
29     chains.push_back(models[i].Order() - exclude_highest);
30     for (std::size_t j = 0; j < models[i].Order() - exclude_highest; ++j) {
31       config.entry_size = sizeof(WordIndex) * (j + 1) + sizeof(float) * 2; // TODO do not include wasteful backoff for highest.
32       chains.back().push_back(config);
33     }
34     if (i == models.size() - 1)
35       chains.back().back().ActivateProgress();
36     models[i].Source(chains.back());
37     for (std::size_t j = 0; j < models[i].Order() - exclude_highest; ++j) {
38       chains[i][j] >> Renumber(vocab.Mapping(i), j + 1);
39     }
40   }
41  for (std::size_t i = 0; i < chains.size(); ++i) {
42     positions.push_back(chains[i]);
43   }
44 }
45 
SinkSort(const util::stream::SortConfig & config,util::stream::Chains & chains,util::stream::Sorts<Compare> & sorts)46 template <class Compare> void SinkSort(const util::stream::SortConfig &config, util::stream::Chains &chains, util::stream::Sorts<Compare> &sorts) {
47   for (std::size_t i = 0; i < chains.size(); ++i) {
48     sorts.push_back(chains[i], config, Compare(i + 1));
49   }
50 }
51 
SourceSort(util::stream::Chains & chains,util::stream::Sorts<Compare> & sorts)52 template <class Compare> void SourceSort(util::stream::Chains &chains, util::stream::Sorts<Compare> &sorts) {
53   // TODO memory management
54   for (std::size_t i = 0; i < sorts.size(); ++i) {
55     sorts[i].Merge(sorts[i].DefaultLazy());
56   }
57   for (std::size_t i = 0; i < sorts.size(); ++i) {
58     sorts[i].Output(chains[i], sorts[i].DefaultLazy());
59   }
60 }
61 
62 } // namespace
63 
// Interpolate the given models into one ARPA file written to write_file.
// Runs three streaming passes: (1) merge probabilities across models,
// (2) normalize, (3) put backoffs next to their probabilities, then print.
// models: one ModelBuffer per input model (already materialized on disk).
// config: interpolation lambdas, buffer sizing, and sort configuration.
// write_file: file descriptor receiving the final ARPA output.
void Pipeline(util::FixedArray<ModelBuffer> &models, const Config &config, int write_file) {
  // Setup InterpolateInfo and UniversalVocab.
  InterpolateInfo info;
  info.lambdas = config.lambdas;
  std::vector<WordIndex> vocab_sizes;

  // Temporary file that collects the merged vocabulary words for printing later.
  util::scoped_fd vocab_null(util::MakeTemp(config.sort.temp_prefix));
  std::size_t max_order = 0;
  util::FixedArray<int> vocab_files(models.size());
  // Gather per-model metadata: order, unigram count (vocab size), vocab file.
  for (ModelBuffer *i = models.begin(); i != models.end(); ++i) {
    info.orders.push_back(i->Order());
    vocab_sizes.push_back(i->Counts()[0]);
    vocab_files.push_back(i->VocabFile());
    max_order = std::max(max_order, i->Order());
  }
  util::scoped_ptr<UniversalVocab> vocab(new UniversalVocab(vocab_sizes));
  {
    // Build the universal (merged) vocabulary mapping; words go to vocab_null.
    ngram::ImmediateWriteWordsWrapper writer(NULL, vocab_null.get(), 0);
    MergeVocab(vocab_files, *vocab, writer);
  }

  std::cerr << "Merging probabilities." << std::endl;
  // Pass 1: merge probabilities
  util::FixedArray<util::stream::Chains> input_chains(models.size());
  util::FixedArray<util::stream::ChainPositions> models_by_order(models.size());
  SetupInputs(config.BufferSize(), *vocab, models, false, input_chains, models_by_order);

  util::stream::Chains merged_probs(max_order);
  for (std::size_t i = 0; i < max_order; ++i) {
    merged_probs.push_back(util::stream::ChainConfig(PartialProbGamma::TotalSize(info, i + 1), 2, config.BufferSize())); // TODO: not buffer_size
  }
  merged_probs >> MergeProbabilities(info, models_by_order);
  // Count merged n-grams per order; needed for buffer sizing and the ARPA header.
  std::vector<uint64_t> counts(max_order);
  for (std::size_t i = 0; i < max_order; ++i) {
    merged_probs[i] >> util::stream::CountRecords(&counts[i]);
  }
  for (util::stream::Chains *i = input_chains.begin(); i != input_chains.end(); ++i) {
    *i >> util::stream::kRecycle;
  }

  // Pass 2: normalize.
  {
    // Sort merged probabilities into context order for normalization.
    util::stream::Sorts<ContextOrder> sorts(merged_probs.size());
    SinkSort(config.sort, merged_probs, sorts);
    // Drain the chains so their memory is free before the merge sort runs.
    merged_probs.Wait(true);
    for (util::stream::Chains *i = input_chains.begin(); i != input_chains.end(); ++i) {
      i->Wait(true);
    }
    SourceSort(merged_probs, sorts);
  }

  std::cerr << "Normalizing" << std::endl;
  // Re-open the inputs, this time excluding each model's highest order.
  SetupInputs(config.BufferSize(), *vocab, models, true, input_chains, models_by_order);
  util::stream::Chains probabilities(max_order), backoffs(max_order - 1);
  std::size_t block_count = 2;
  for (std::size_t i = 0; i < max_order; ++i) {
    // Careful accounting to ensure RewindableStream can fit the entire vocabulary.
    block_count = std::max<std::size_t>(block_count, 2);
    // This much needs to fit in RewindableStream.
    // counts[0] is the unigram count, i.e. the universal vocabulary size.
    std::size_t fit = NGram<float>::TotalSize(i + 1) * counts[0];
    // fit / (block_count - 1) rounded up
    std::size_t min_block = (fit + block_count - 2) / (block_count - 1);
    std::size_t specify = std::max(config.BufferSize(), min_block * block_count);
    probabilities.push_back(util::stream::ChainConfig(NGram<float>::TotalSize(i + 1), block_count, specify));
  }
  // Backoffs exist only for orders below the highest, hence max_order - 1 chains.
  for (std::size_t i = 0; i < max_order - 1; ++i) {
    backoffs.push_back(util::stream::ChainConfig(sizeof(float), 2, config.BufferSize()));
  }
  Normalize(info, models_by_order, merged_probs, probabilities, backoffs);
  // Spill backoffs to temporary files; they are re-read in pass 3.
  util::FixedArray<util::stream::FileBuffer> backoff_buffers(backoffs.size());
  for (std::size_t i = 0; i < max_order - 1; ++i) {
    backoff_buffers.push_back(util::MakeTemp(config.sort.temp_prefix));
    backoffs[i] >> backoff_buffers.back().Sink() >> util::stream::kRecycle;
  }
  for (util::stream::Chains *i = input_chains.begin(); i != input_chains.end(); ++i) {
    *i >> util::stream::kRecycle;
  }
  merged_probs >> util::stream::kRecycle;

  // Pass 3: backoffs in the right place.
  {
    // Sort normalized probabilities into suffix order for backoff reunification.
    util::stream::Sorts<SuffixOrder> sorts(probabilities.size());
    SinkSort(config.sort, probabilities, sorts);
    probabilities.Wait(true);
    for (util::stream::Chains *i = input_chains.begin(); i != input_chains.end(); ++i) {
      i->Wait(true);
    }
    backoffs.Wait(true);
    merged_probs.Wait(true);
    // destroy universal vocab to save RAM.
    vocab.reset();
    SourceSort(probabilities, sorts);
  }

  std::cerr << "Reunifying backoffs" << std::endl;
  util::stream::ChainPositions prob_pos(max_order - 1);
  util::stream::Chains combined(max_order - 1);
  for (std::size_t i = 0; i < max_order - 1; ++i) {
    // Progress is reported on the last (highest-order) backoff chain only.
    if (i == max_order - 2)
      backoffs[i].ActivateProgress();
    backoffs[i].SetProgressTarget(backoff_buffers[i].Size());
    // Read the spilled backoffs back in (true = delete the temp file after).
    backoffs[i] >> backoff_buffers[i].Source(true);
    prob_pos.push_back(probabilities[i].Add());
    combined.push_back(util::stream::ChainConfig(NGram<ProbBackoff>::TotalSize(i + 1), 2, config.BufferSize()));
  }
  util::stream::ChainPositions backoff_pos(backoffs);

  ReunifyBackoff(prob_pos, backoff_pos, combined);

  // Output: combined (prob + backoff) for lower orders, prob-only for the highest.
  util::stream::ChainPositions output_pos(max_order);
  for (std::size_t i = 0; i < max_order - 1; ++i) {
    output_pos.push_back(combined[i].Add());
  }
  output_pos.push_back(probabilities.back().Add());

  probabilities >> util::stream::kRecycle;
  backoffs >> util::stream::kRecycle;
  combined >> util::stream::kRecycle;

  // TODO genericize to ModelBuffer etc.
  PrintARPA(vocab_null.get(), write_file, counts).Run(output_pos);
}
186 
187 }} // namespaces
188