/* This is where the trie is built.  It's on-disk.  */
#include "lm/search_trie.hh"

#include "lm/bhiksha.hh"
#include "lm/binary_format.hh"
#include "lm/blank.hh"
#include "lm/lm_exception.hh"
#include "lm/max_order.hh"
#include "lm/quantize.hh"
#include "lm/trie.hh"
#include "lm/trie_sort.hh"
#include "lm/vocab.hh"
#include "lm/weights.hh"
#include "lm/word_index.hh"
#include "util/ersatz_progress.hh"
#include "util/mmap.hh"
#include "util/proxy_iterator.hh"
#include "util/scoped.hh"
#include "util/sized_iterator.hh"

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <numeric>
#include <queue>
#include <vector>

#if defined(_WIN32) || defined(_WIN64)
#include <windows.h>
#endif
33 
34 namespace lm {
35 namespace ngram {
36 namespace trie {
37 namespace {
38 
ReadOrThrow(FILE * from,void * data,size_t size)39 void ReadOrThrow(FILE *from, void *data, size_t size) {
40   UTIL_THROW_IF(1 != std::fread(data, size, 1, from), util::ErrnoException, "Short read");
41 }
42 
Compare(unsigned char order,const void * first_void,const void * second_void)43 int Compare(unsigned char order, const void *first_void, const void *second_void) {
44   const WordIndex *first = reinterpret_cast<const WordIndex*>(first_void), *second = reinterpret_cast<const WordIndex*>(second_void);
45   const WordIndex *end = first + order;
46   for (; first != end; ++first, ++second) {
47     if (*first < *second) return -1;
48     if (*first > *second) return 1;
49   }
50   return 0;
51 }
52 
// Locates a probability float stored in SRISucks: values_[array][index].
struct ProbPointer {
  unsigned char array;  // order - 1: selects the per-order float array.
  uint64_t index;       // offset within that array.
};
57 
58 // Array of n-grams and float indices.
59 class BackoffMessages {
60   public:
Init(std::size_t entry_size)61     void Init(std::size_t entry_size) {
62       current_ = NULL;
63       allocated_ = NULL;
64       entry_size_ = entry_size;
65     }
66 
Add(const WordIndex * to,ProbPointer index)67     void Add(const WordIndex *to, ProbPointer index) {
68       while (current_ + entry_size_ > allocated_) {
69         std::size_t allocated_size = allocated_ - (uint8_t*)backing_.get();
70         Resize(std::max<std::size_t>(allocated_size * 2, entry_size_));
71       }
72       memcpy(current_, to, entry_size_ - sizeof(ProbPointer));
73       *reinterpret_cast<ProbPointer*>(current_ + entry_size_ - sizeof(ProbPointer)) = index;
74       current_ += entry_size_;
75     }
76 
Apply(float * const * const base,FILE * unigrams)77     void Apply(float *const *const base, FILE *unigrams) {
78       FinishedAdding();
79       if (current_ == allocated_) return;
80       rewind(unigrams);
81       ProbBackoff weights;
82       WordIndex unigram = 0;
83       ReadOrThrow(unigrams, &weights, sizeof(weights));
84       for (; current_ != allocated_; current_ += entry_size_) {
85         const WordIndex &cur_word = *reinterpret_cast<const WordIndex*>(current_);
86         for (; unigram < cur_word; ++unigram) {
87           ReadOrThrow(unigrams, &weights, sizeof(weights));
88         }
89         if (!HasExtension(weights.backoff)) {
90           weights.backoff = kExtensionBackoff;
91           UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed.");
92           util::WriteOrThrow(unigrams, &weights, sizeof(weights));
93         }
94         const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + sizeof(WordIndex));
95         base[write_to.array][write_to.index] += weights.backoff;
96       }
97       backing_.reset();
98     }
99 
Apply(float * const * const base,RecordReader & reader)100     void Apply(float *const *const base, RecordReader &reader) {
101       FinishedAdding();
102       if (current_ == allocated_) return;
103       // We'll also use the same buffer to record messages to blanks that they extend.
104       WordIndex *extend_out = reinterpret_cast<WordIndex*>(current_);
105       const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex);
106       for (reader.Rewind(); reader && (current_ != allocated_); ) {
107         switch (Compare(order, reader.Data(), current_)) {
108           case -1:
109             ++reader;
110             break;
111           case 1:
112             // Message but nobody to receive it.  Write it down at the beginning of the buffer so we can inform this blank that it extends.
113             for (const WordIndex *w = reinterpret_cast<const WordIndex *>(current_); w != reinterpret_cast<const WordIndex *>(current_) + order; ++w, ++extend_out) *extend_out = *w;
114             current_ += entry_size_;
115             break;
116           case 0:
117             float &backoff = reinterpret_cast<ProbBackoff*>((uint8_t*)reader.Data() + order * sizeof(WordIndex))->backoff;
118             if (!HasExtension(backoff)) {
119               backoff = kExtensionBackoff;
120               reader.Overwrite(&backoff, sizeof(float));
121             } else {
122               const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + entry_size_ - sizeof(ProbPointer));
123               base[write_to.array][write_to.index] += backoff;
124             }
125             current_ += entry_size_;
126             break;
127         }
128       }
129       // Now this is a list of blanks that extend right.
130       entry_size_ = sizeof(WordIndex) * order;
131       Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get()));
132       current_ = (uint8_t*)backing_.get();
133     }
134 
135     // Call after Apply
Extends(unsigned char order,const WordIndex * words)136     bool Extends(unsigned char order, const WordIndex *words) {
137       if (current_ == allocated_) return false;
138       assert(order * sizeof(WordIndex) == entry_size_);
139       while (true) {
140         switch(Compare(order, words, current_)) {
141           case 1:
142             current_ += entry_size_;
143             if (current_ == allocated_) return false;
144             break;
145           case -1:
146             return false;
147           case 0:
148             return true;
149         }
150       }
151     }
152 
153   private:
FinishedAdding()154     void FinishedAdding() {
155       Resize(current_ - (uint8_t*)backing_.get());
156       // Sort requests in same order as files.
157       util::SizedSort(backing_.get(), current_, entry_size_, EntryCompare((entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex)));
158       current_ = (uint8_t*)backing_.get();
159     }
160 
Resize(std::size_t to)161     void Resize(std::size_t to) {
162       std::size_t current = current_ - (uint8_t*)backing_.get();
163       backing_.call_realloc(to);
164       current_ = (uint8_t*)backing_.get() + current;
165       allocated_ = (uint8_t*)backing_.get() + to;
166     }
167 
168     util::scoped_malloc backing_;
169 
170     uint8_t *current_, *allocated_;
171 
172     std::size_t entry_size_;
173 };
174 
175 const float kBadProb = std::numeric_limits<float>::infinity();
176 
177 class SRISucks {
178   public:
SRISucks()179     SRISucks() {
180       for (BackoffMessages *i = messages_; i != messages_ + KENLM_MAX_ORDER - 1; ++i)
181         i->Init(sizeof(ProbPointer) + sizeof(WordIndex) * (i - messages_ + 1));
182     }
183 
Send(unsigned char begin,unsigned char order,const WordIndex * to,float prob_basis)184     void Send(unsigned char begin, unsigned char order, const WordIndex *to, float prob_basis) {
185       assert(prob_basis != kBadProb);
186       ProbPointer pointer;
187       pointer.array = order - 1;
188       pointer.index = values_[order - 1].size();
189       for (unsigned char i = begin; i < order; ++i) {
190         messages_[i - 1].Add(to, pointer);
191       }
192       values_[order - 1].push_back(prob_basis);
193     }
194 
ObtainBackoffs(unsigned char total_order,FILE * unigram_file,RecordReader * reader)195     void ObtainBackoffs(unsigned char total_order, FILE *unigram_file, RecordReader *reader) {
196       for (unsigned char i = 0; i < KENLM_MAX_ORDER - 1; ++i) {
197         it_[i] = values_[i].empty() ? NULL : &*values_[i].begin();
198       }
199       messages_[0].Apply(it_, unigram_file);
200       BackoffMessages *messages = messages_ + 1;
201       const RecordReader *end = reader + total_order - 2 /* exclude unigrams and longest order */;
202       for (; reader != end; ++messages, ++reader) {
203         messages->Apply(it_, *reader);
204       }
205     }
206 
GetBlank(unsigned char total_order,unsigned char order,const WordIndex * indices)207     ProbBackoff GetBlank(unsigned char total_order, unsigned char order, const WordIndex *indices) {
208       assert(order > 1);
209       ProbBackoff ret;
210       ret.prob = *(it_[order - 1]++);
211       ret.backoff = ((order != total_order - 1) && messages_[order - 1].Extends(order, indices)) ? kExtensionBackoff : kNoExtensionBackoff;
212       return ret;
213     }
214 
Values(unsigned char order) const215     const std::vector<float> &Values(unsigned char order) const {
216       return values_[order - 1];
217     }
218 
219   private:
220     // This used to be one array.  Then I needed to separate it by order for quantization to work.
221     std::vector<float> values_[KENLM_MAX_ORDER - 1];
222     BackoffMessages messages_[KENLM_MAX_ORDER - 1];
223 
224     float *it_[KENLM_MAX_ORDER - 1];
225 };
226 
227 class FindBlanks {
228   public:
FindBlanks(unsigned char order,const ProbBackoff * unigrams,SRISucks & messages)229     FindBlanks(unsigned char order, const ProbBackoff *unigrams, SRISucks &messages)
230       : counts_(order), unigrams_(unigrams), sri_(messages) {}
231 
UnigramProb(WordIndex index) const232     float UnigramProb(WordIndex index) const {
233       return unigrams_[index].prob;
234     }
235 
Unigram(WordIndex)236     void Unigram(WordIndex /*index*/) {
237       ++counts_[0];
238     }
239 
MiddleBlank(const unsigned char order,const WordIndex * indices,unsigned char lower,float prob_basis)240     void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char lower, float prob_basis) {
241       sri_.Send(lower, order, indices + 1, prob_basis);
242       ++counts_[order - 1];
243     }
244 
Middle(const unsigned char order,const void *)245     void Middle(const unsigned char order, const void * /*data*/) {
246       ++counts_[order - 1];
247     }
248 
Longest(const void *)249     void Longest(const void * /*data*/) {
250       ++counts_.back();
251     }
252 
Counts() const253     const std::vector<uint64_t> &Counts() const {
254       return counts_;
255     }
256 
257   private:
258     std::vector<uint64_t> counts_;
259 
260     const ProbBackoff *unigrams_;
261 
262     SRISucks &sri_;
263 };
264 
265 // Phase to actually write n-grams to the trie.
266 template <class Quant, class Bhiksha> class WriteEntries {
267   public:
WriteEntries(RecordReader * contexts,const Quant & quant,UnigramValue * unigrams,BitPackedMiddle<Bhiksha> * middle,BitPackedLongest & longest,unsigned char order,SRISucks & sri)268     WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) :
269       contexts_(contexts),
270       quant_(quant),
271       unigrams_(unigrams),
272       middle_(middle),
273       longest_(longest),
274       bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)),
275       order_(order),
276       sri_(sri) {}
277 
UnigramProb(WordIndex index) const278     float UnigramProb(WordIndex index) const { return unigrams_[index].weights.prob; }
279 
Unigram(WordIndex word)280     void Unigram(WordIndex word) {
281       unigrams_[word].next = bigram_pack_.InsertIndex();
282     }
283 
MiddleBlank(const unsigned char order,const WordIndex * indices,unsigned char,float)284     void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char /*lower*/, float /*prob_base*/) {
285       ProbBackoff weights = sri_.GetBlank(order_, order, indices);
286       typename Quant::MiddlePointer(quant_, order - 2, middle_[order - 2].Insert(indices[order - 1])).Write(weights.prob, weights.backoff);
287     }
288 
Middle(const unsigned char order,const void * data)289     void Middle(const unsigned char order, const void *data) {
290       RecordReader &context = contexts_[order - 1];
291       const WordIndex *words = reinterpret_cast<const WordIndex*>(data);
292       ProbBackoff weights = *reinterpret_cast<const ProbBackoff*>(words + order);
293       if (context && !memcmp(data, context.Data(), sizeof(WordIndex) * order)) {
294         SetExtension(weights.backoff);
295         ++context;
296       }
297       typename Quant::MiddlePointer(quant_, order - 2, middle_[order - 2].Insert(words[order - 1])).Write(weights.prob, weights.backoff);
298     }
299 
Longest(const void * data)300     void Longest(const void *data) {
301       const WordIndex *words = reinterpret_cast<const WordIndex*>(data);
302       typename Quant::LongestPointer(quant_, longest_.Insert(words[order_ - 1])).Write(reinterpret_cast<const Prob*>(words + order_)->prob);
303     }
304 
305   private:
306     RecordReader *contexts_;
307     const Quant &quant_;
308     UnigramValue *const unigrams_;
309     BitPackedMiddle<Bhiksha> *const middle_;
310     BitPackedLongest &longest_;
311     BitPacked &bigram_pack_;
312     const unsigned char order_;
313     SRISucks &sri_;
314 };
315 
316 struct Gram {
Gramlm::ngram::trie::__anon19ae4fd80111::Gram317   Gram(const WordIndex *in_begin, unsigned char order) : begin(in_begin), end(in_begin + order) {}
318 
319   const WordIndex *begin, *end;
320 
321   // For queue, this is the direction we want.
operator <lm::ngram::trie::__anon19ae4fd80111::Gram322   bool operator<(const Gram &other) const {
323     return std::lexicographical_compare(other.begin, other.end, begin, end);
324   }
325 };
326 
327 template <class Doing> class BlankManager {
328   public:
BlankManager(unsigned char total_order,Doing & doing)329     BlankManager(unsigned char total_order, Doing &doing) : total_order_(total_order), been_length_(0), doing_(doing) {
330       for (float *i = basis_; i != basis_ + KENLM_MAX_ORDER - 1; ++i) *i = kBadProb;
331     }
332 
Visit(const WordIndex * to,unsigned char length,float prob)333     void Visit(const WordIndex *to, unsigned char length, float prob) {
334       basis_[length - 1] = prob;
335       unsigned char overlap = std::min<unsigned char>(length - 1, been_length_);
336       const WordIndex *cur;
337       WordIndex *pre;
338       for (cur = to, pre = been_; cur != to + overlap; ++cur, ++pre) {
339         if (*pre != *cur) break;
340       }
341       if (cur == to + length - 1) {
342         *pre = *cur;
343         been_length_ = length;
344         return;
345       }
346       // There are blanks to insert starting with order blank.
347       unsigned char blank = cur - to + 1;
348       UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context.");
349       const float *lower_basis;
350       for (lower_basis = basis_ + blank - 2; *lower_basis == kBadProb; --lower_basis) {}
351       unsigned char based_on = lower_basis - basis_ + 1;
352       for (; cur != to + length - 1; ++blank, ++cur, ++pre) {
353         assert(*lower_basis != kBadProb);
354         doing_.MiddleBlank(blank, to, based_on, *lower_basis);
355         *pre = *cur;
356         // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram.
357         basis_[blank - 1] = kBadProb;
358       }
359       *pre = *cur;
360       been_length_ = length;
361     }
362 
363   private:
364     const unsigned char total_order_;
365 
366     WordIndex been_[KENLM_MAX_ORDER];
367     unsigned char been_length_;
368 
369     float basis_[KENLM_MAX_ORDER];
370 
371     Doing &doing_;
372 };
373 
RecursiveInsert(const unsigned char total_order,const WordIndex unigram_count,RecordReader * input,std::ostream * progress_out,const char * message,Doing & doing)374 template <class Doing> void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) {
375   util::ErsatzProgress progress(unigram_count + 1, progress_out, message);
376   WordIndex unigram = 0;
377   std::priority_queue<Gram> grams;
378   if (unigram_count) grams.push(Gram(&unigram, 1));
379   for (unsigned char i = 2; i <= total_order; ++i) {
380     if (input[i-2]) grams.push(Gram(reinterpret_cast<const WordIndex*>(input[i-2].Data()), i));
381   }
382 
383   BlankManager<Doing> blank(total_order, doing);
384 
385   while (!grams.empty()) {
386     Gram top = grams.top();
387     grams.pop();
388     unsigned char order = top.end - top.begin;
389     if (order == 1) {
390       blank.Visit(&unigram, 1, doing.UnigramProb(unigram));
391       doing.Unigram(unigram);
392       progress.Set(unigram);
393       if (++unigram < unigram_count) grams.push(top);
394     } else {
395       if (order == total_order) {
396         blank.Visit(top.begin, order, reinterpret_cast<const Prob*>(top.end)->prob);
397         doing.Longest(top.begin);
398       } else {
399         blank.Visit(top.begin, order, reinterpret_cast<const ProbBackoff*>(top.end)->prob);
400         doing.Middle(order, top.begin);
401       }
402       RecordReader &reader = input[order - 2];
403       if (++reader) grams.push(top);
404     }
405   }
406 }
407 
SanityCheckCounts(const std::vector<uint64_t> & initial,const std::vector<uint64_t> & fixed)408 void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) {
409   if (fixed[0] != initial[0]) UTIL_THROW(util::Exception, "Unigram count should be constant but initial is " << initial[0] << " and recounted is " << fixed[0]);
410   if (fixed.back() != initial.back()) UTIL_THROW(util::Exception, "Longest count should be constant but it changed from " << initial.back() << " to " << fixed.back());
411   for (unsigned char i = 0; i < initial.size(); ++i) {
412     if (fixed[i] < initial[i]) UTIL_THROW(util::Exception, "Counts came out lower than expected.  This shouldn't happen");
413   }
414 }
415 
TrainQuantizer(uint8_t order,uint64_t count,const std::vector<float> & additional,RecordReader & reader,util::ErsatzProgress & progress,Quant & quant)416 template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, const std::vector<float> &additional, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) {
417   std::vector<float> probs(additional), backoffs;
418   probs.reserve(count + additional.size());
419   backoffs.reserve(count);
420   for (reader.Rewind(); reader; ++reader) {
421     const ProbBackoff &weights = *reinterpret_cast<const ProbBackoff*>(reinterpret_cast<const uint8_t*>(reader.Data()) + sizeof(WordIndex) * order);
422     probs.push_back(weights.prob);
423     if (weights.backoff != 0.0) backoffs.push_back(weights.backoff);
424     ++progress;
425   }
426   quant.Train(order, probs, backoffs);
427 }
428 
TrainProbQuantizer(uint8_t order,uint64_t count,RecordReader & reader,util::ErsatzProgress & progress,Quant & quant)429 template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) {
430   std::vector<float> probs, backoffs;
431   probs.reserve(count);
432   for (reader.Rewind(); reader; ++reader) {
433     const Prob &weights = *reinterpret_cast<const Prob*>(reinterpret_cast<const uint8_t*>(reader.Data()) + sizeof(WordIndex) * order);
434     probs.push_back(weights.prob);
435     ++progress;
436   }
437   quant.TrainProb(order, probs);
438 }
439 
PopulateUnigramWeights(FILE * file,WordIndex unigram_count,RecordReader & contexts,UnigramValue * unigrams)440 void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) {
441   // Fill unigram probabilities.
442   try {
443     rewind(file);
444     for (WordIndex i = 0; i < unigram_count; ++i) {
445       ReadOrThrow(file, &unigrams[i].weights, sizeof(ProbBackoff));
446       if (contexts && *reinterpret_cast<const WordIndex*>(contexts.Data()) == i) {
447         SetExtension(unigrams[i].weights.backoff);
448         ++contexts;
449       }
450     }
451   } catch (util::Exception &e) {
452     e << " while re-reading unigram probabilities";
453     throw;
454   }
455 }
456 
457 } // namespace
458 
BuildTrie(SortedFiles & files,std::vector<uint64_t> & counts,const Config & config,TrieSearch<Quant,Bhiksha> & out,Quant & quant,SortedVocabulary & vocab,BinaryFormat & backing)459 template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing) {
460   RecordReader inputs[KENLM_MAX_ORDER - 1];
461   RecordReader contexts[KENLM_MAX_ORDER - 1];
462 
463   for (unsigned char i = 2; i <= counts.size(); ++i) {
464     inputs[i-2].Init(files.Full(i), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff)));
465     contexts[i-2].Init(files.Context(i), (i-1) * sizeof(WordIndex));
466   }
467 
468   SRISucks sri;
469   std::vector<uint64_t> fixed_counts;
470   util::scoped_FILE unigram_file;
471   util::scoped_fd unigram_fd(files.StealUnigram());
472   {
473     util::scoped_memory unigrams;
474     MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);
475     FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri);
476     RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Identifying n-grams omitted by SRI", finder);
477     fixed_counts = finder.Counts();
478   }
479   unigram_file.reset(util::FDOpenOrThrow(unigram_fd));
480   for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) {
481     if (*i) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading");
482   }
483   SanityCheckCounts(counts, fixed_counts);
484   counts = fixed_counts;
485 
486   sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs);
487 
488   void *vocab_relocate;
489   void *search_base = backing.GrowForSearch(TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), vocab.UnkCountChangePadding(), vocab_relocate);
490   vocab.Relocate(vocab_relocate);
491   out.SetupMemory(reinterpret_cast<uint8_t*>(search_base), fixed_counts, config);
492 
493   for (unsigned char i = 2; i <= counts.size(); ++i) {
494     inputs[i-2].Rewind();
495   }
496   if (Quant::kTrain) {
497     util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0),
498                                   config.ProgressMessages(), "Quantizing");
499     for (unsigned char i = 2; i < counts.size(); ++i) {
500       TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant);
501     }
502     TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant);
503     quant.FinishedLoading(config);
504   }
505 
506   UnigramValue *unigrams = out.unigram_.Raw();
507   PopulateUnigramWeights(unigram_file.get(), counts[0], contexts[0], unigrams);
508   unigram_file.reset();
509 
510   for (unsigned char i = 2; i <= counts.size(); ++i) {
511     inputs[i-2].Rewind();
512   }
513   // Fill entries except unigram probabilities.
514   {
515     WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri);
516     RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer);
517     // Write the last unigram entry, which is the end pointer for the bigrams.
518     writer.Unigram(counts[0]);
519   }
520 
521   // Do not disable this error message or else too little state will be returned.  Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
522   for (unsigned char order = 2; order <= counts.size(); ++order) {
523     const RecordReader &context = contexts[order - 2];
524     if (context) {
525       FormatLoadException e;
526       e << "A " << static_cast<unsigned int>(order) << "-gram has context";
527       const WordIndex *ctx = reinterpret_cast<const WordIndex*>(context.Data());
528       for (const WordIndex *i = ctx; i != ctx + order - 1; ++i) {
529         e << ' ' << *i;
530       }
531       e << " so this context must appear in the model as a " << static_cast<unsigned int>(order - 1) << "-gram but it does not";
532       throw e;
533     }
534   }
535 
536   /* Set ending offsets so the last entry will be sized properly */
537   // Last entry for unigrams was already set.
538   if (out.middle_begin_ != out.middle_end_) {
539     for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {
540       i->FinishedLoading((i+1)->InsertIndex(), config);
541     }
542     (out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config);
543   }
544 }
545 
SetupMemory(uint8_t * start,const std::vector<uint64_t> & counts,const Config & config)546 template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
547   quant_.SetupMemory(start, counts.size(), config);
548   start += Quant::Size(counts.size(), config);
549   unigram_.Init(start);
550   start += Unigram::Size(counts[0]);
551   FreeMiddles();
552   middle_begin_ = static_cast<Middle*>(malloc(sizeof(Middle) * (counts.size() - 2)));
553   middle_end_ = middle_begin_ + (counts.size() - 2);
554   std::vector<uint8_t*> middle_starts(counts.size() - 2);
555   for (unsigned char i = 2; i < counts.size(); ++i) {
556     middle_starts[i-2] = start;
557     start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i], config);
558   }
559   // Crazy backwards thing so we initialize using pointers to ones that have already been initialized
560   for (unsigned char i = counts.size() - 1; i >= 2; --i) {
561     // use "placement new" syntax to initalize Middle in an already-allocated memory location
562     new (middle_begin_ + i - 2) Middle(
563         middle_starts[i-2],
564         quant_.MiddleBits(config),
565         counts[i-1],
566         counts[0],
567         counts[i],
568         (i == counts.size() - 1) ? static_cast<const BitPacked&>(longest_) : static_cast<const BitPacked &>(middle_begin_[i-1]),
569         config);
570   }
571   longest_.Init(start, quant_.LongestBits(config), counts[0]);
572   return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
573 }
574 
InitializeFromARPA(const char * file,util::FilePiece & f,std::vector<uint64_t> & counts,const Config & config,SortedVocabulary & vocab,BinaryFormat & backing)575 template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing) {
576   std::string temporary_prefix;
577   if (!config.temporary_directory_prefix.empty()) {
578     temporary_prefix = config.temporary_directory_prefix;
579   } else if (config.write_mmap) {
580     temporary_prefix = config.write_mmap;
581   } else {
582     temporary_prefix = file;
583   }
584   // At least 1MB sorting memory.
585   SortedFiles sorted(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_prefix, vocab);
586 
587   BuildTrie(sorted, counts, config, *this, quant_, vocab, backing);
588 }
589 
590 template class TrieSearch<DontQuantize, DontBhiksha>;
591 template class TrieSearch<DontQuantize, ArrayBhiksha>;
592 template class TrieSearch<SeparatelyQuantize, DontBhiksha>;
593 template class TrieSearch<SeparatelyQuantize, ArrayBhiksha>;
594 
595 } // namespace trie
596 } // namespace ngram
597 } // namespace lm
598