1 ///////////////////////////////////////////////////////////////////////
2 // File:        recodebeam.cpp
3 // Description: Beam search to decode from the re-encoded CJK as a sequence of
4 //              smaller numbers in place of a single large code.
5 // Author:      Ray Smith
6 //
7 // (C) Copyright 2015, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
17 //
18 ///////////////////////////////////////////////////////////////////////
19 
20 #include "recodebeam.h"
21 
22 #include "networkio.h"
23 #include "pageres.h"
24 #include "unicharcompress.h"
25 
26 #include <algorithm> // for std::reverse
27 #include <deque>
28 #include <map>
29 #include <set>
30 #include <tuple>
31 #include <unordered_set>
32 #include <vector>
33 
34 namespace tesseract {
35 
36 // The beam width at each code position.
37 const int RecodeBeamSearch::kBeamWidths[RecodedCharID::kMaxCodeLen + 1] = {
38     5, 10, 16, 16, 16, 16, 16, 16, 16, 16,
39 };
40 
41 static const char *kNodeContNames[] = {"Anything", "OnlyDup", "NoDup"};
42 
43 // Prints debug details of the node.
Print(int null_char,const UNICHARSET & unicharset,int depth) const44 void RecodeNode::Print(int null_char, const UNICHARSET &unicharset,
45                        int depth) const {
46   if (code == null_char) {
47     tprintf("null_char");
48   } else {
49     tprintf("label=%d, uid=%d=%s", code, unichar_id,
50             unicharset.debug_str(unichar_id).c_str());
51   }
52   tprintf(" score=%g, c=%g,%s%s%s perm=%d, hash=%" PRIx64, score, certainty,
53           start_of_dawg ? " DawgStart" : "", start_of_word ? " Start" : "",
54           end_of_word ? " End" : "", permuter, code_hash);
55   if (depth > 0 && prev != nullptr) {
56     tprintf(" prev:");
57     prev->Print(null_char, unicharset, depth - 1);
58   } else {
59     tprintf("\n");
60   }
61 }
62 
63 // Borrows the pointer, which is expected to survive until *this is deleted.
RecodeBeamSearch(const UnicharCompress & recoder,int null_char,bool simple_text,Dict * dict)64 RecodeBeamSearch::RecodeBeamSearch(const UnicharCompress &recoder,
65                                    int null_char, bool simple_text, Dict *dict)
66     : recoder_(recoder),
67       beam_size_(0),
68       top_code_(-1),
69       second_code_(-1),
70       dict_(dict),
71       space_delimited_(true),
72       is_simple_text_(simple_text),
73       null_char_(null_char) {
74   if (dict_ != nullptr && !dict_->IsSpaceDelimitedLang()) {
75     space_delimited_ = false;
76   }
77 }
78 
~RecodeBeamSearch()79 RecodeBeamSearch::~RecodeBeamSearch() {
80   for (auto data : beam_) {
81     delete data;
82   }
83   for (auto data : secondary_beam_) {
84     delete data;
85   }
86 }
87 
88 // Decodes the set of network outputs, storing the lattice internally.
Decode(const NetworkIO & output,double dict_ratio,double cert_offset,double worst_dict_cert,const UNICHARSET * charset,int lstm_choice_mode)89 void RecodeBeamSearch::Decode(const NetworkIO &output, double dict_ratio,
90                               double cert_offset, double worst_dict_cert,
91                               const UNICHARSET *charset, int lstm_choice_mode) {
92   beam_size_ = 0;
93   int width = output.Width();
94   if (lstm_choice_mode) {
95     timesteps.clear();
96   }
97   for (int t = 0; t < width; ++t) {
98     ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0]);
99     DecodeStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert,
100                charset);
101     if (lstm_choice_mode) {
102       SaveMostCertainChoices(output.f(t), output.NumFeatures(), charset, t);
103     }
104   }
105 }
Decode(const GENERIC_2D_ARRAY<float> & output,double dict_ratio,double cert_offset,double worst_dict_cert,const UNICHARSET * charset)106 void RecodeBeamSearch::Decode(const GENERIC_2D_ARRAY<float> &output,
107                               double dict_ratio, double cert_offset,
108                               double worst_dict_cert,
109                               const UNICHARSET *charset) {
110   beam_size_ = 0;
111   int width = output.dim1();
112   for (int t = 0; t < width; ++t) {
113     ComputeTopN(output[t], output.dim2(), kBeamWidths[0]);
114     DecodeStep(output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset);
115   }
116 }
117 
DecodeSecondaryBeams(const NetworkIO & output,double dict_ratio,double cert_offset,double worst_dict_cert,const UNICHARSET * charset,int lstm_choice_mode)118 void RecodeBeamSearch::DecodeSecondaryBeams(
119     const NetworkIO &output, double dict_ratio, double cert_offset,
120     double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode) {
121   for (auto data : secondary_beam_) {
122     delete data;
123   }
124   secondary_beam_.clear();
125   if (character_boundaries_.size() < 2) {
126     return;
127   }
128   int width = output.Width();
129   unsigned bucketNumber = 0;
130   for (int t = 0; t < width; ++t) {
131     while ((bucketNumber + 1) < character_boundaries_.size() &&
132            t >= character_boundaries_[bucketNumber + 1]) {
133       ++bucketNumber;
134     }
135     ComputeSecTopN(&(excludedUnichars)[bucketNumber], output.f(t),
136                    output.NumFeatures(), kBeamWidths[0]);
137     DecodeSecondaryStep(output.f(t), t, dict_ratio, cert_offset,
138                         worst_dict_cert, charset);
139   }
140 }
141 
SaveMostCertainChoices(const float * outputs,int num_outputs,const UNICHARSET * charset,int xCoord)142 void RecodeBeamSearch::SaveMostCertainChoices(const float *outputs,
143                                               int num_outputs,
144                                               const UNICHARSET *charset,
145                                               int xCoord) {
146   std::vector<std::pair<const char *, float>> choices;
147   for (int i = 0; i < num_outputs; ++i) {
148     if (outputs[i] >= 0.01f) {
149       const char *character;
150       if (i + 2 >= num_outputs) {
151         character = "";
152       } else if (i > 0) {
153         character = charset->id_to_unichar_ext(i + 2);
154       } else {
155         character = charset->id_to_unichar_ext(i);
156       }
157       size_t pos = 0;
158       // order the possible choices within one timestep
159       // beginning with the most likely
160       while (choices.size() > pos && choices[pos].second > outputs[i]) {
161         pos++;
162       }
163       choices.insert(choices.begin() + pos,
164                      std::pair<const char *, float>(character, outputs[i]));
165     }
166   }
167   timesteps.push_back(choices);
168 }
169 
segmentTimestepsByCharacters()170 void RecodeBeamSearch::segmentTimestepsByCharacters() {
171   for (unsigned i = 1; i < character_boundaries_.size(); ++i) {
172     std::vector<std::vector<std::pair<const char *, float>>> segment;
173     for (int j = character_boundaries_[i - 1]; j < character_boundaries_[i];
174          ++j) {
175       segment.push_back(timesteps[j]);
176     }
177     segmentedTimesteps.push_back(segment);
178   }
179 }
180 std::vector<std::vector<std::pair<const char *, float>>>
combineSegmentedTimesteps(std::vector<std::vector<std::vector<std::pair<const char *,float>>>> * segmentedTimesteps)181 RecodeBeamSearch::combineSegmentedTimesteps(
182     std::vector<std::vector<std::vector<std::pair<const char *, float>>>>
183         *segmentedTimesteps) {
184   std::vector<std::vector<std::pair<const char *, float>>> combined_timesteps;
185   for (auto &segmentedTimestep : *segmentedTimesteps) {
186     for (auto &j : segmentedTimestep) {
187       combined_timesteps.push_back(j);
188     }
189   }
190   return combined_timesteps;
191 }
192 
calculateCharBoundaries(std::vector<int> * starts,std::vector<int> * ends,std::vector<int> * char_bounds_,int maxWidth)193 void RecodeBeamSearch::calculateCharBoundaries(std::vector<int> *starts,
194                                                std::vector<int> *ends,
195                                                std::vector<int> *char_bounds_,
196                                                int maxWidth) {
197   char_bounds_->push_back(0);
198   for (unsigned i = 0; i < ends->size(); ++i) {
199     int middle = ((*starts)[i + 1] - (*ends)[i]) / 2;
200     char_bounds_->push_back((*ends)[i] + middle);
201   }
202   char_bounds_->pop_back();
203   char_bounds_->push_back(maxWidth);
204 }
205 
206 // Returns the best path as labels/scores/xcoords similar to simple CTC.
ExtractBestPathAsLabels(std::vector<int> * labels,std::vector<int> * xcoords) const207 void RecodeBeamSearch::ExtractBestPathAsLabels(
208     std::vector<int> *labels, std::vector<int> *xcoords) const {
209   labels->clear();
210   xcoords->clear();
211   std::vector<const RecodeNode *> best_nodes;
212   ExtractBestPaths(&best_nodes, nullptr);
213   // Now just run CTC on the best nodes.
214   int t = 0;
215   int width = best_nodes.size();
216   while (t < width) {
217     int label = best_nodes[t]->code;
218     if (label != null_char_) {
219       labels->push_back(label);
220       xcoords->push_back(t);
221     }
222     while (++t < width && !is_simple_text_ && best_nodes[t]->code == label) {
223     }
224   }
225   xcoords->push_back(width);
226 }
227 
228 // Returns the best path as unichar-ids/certs/ratings/xcoords skipping
229 // duplicates, nulls and intermediate parts.
ExtractBestPathAsUnicharIds(bool debug,const UNICHARSET * unicharset,std::vector<int> * unichar_ids,std::vector<float> * certs,std::vector<float> * ratings,std::vector<int> * xcoords) const230 void RecodeBeamSearch::ExtractBestPathAsUnicharIds(
231     bool debug, const UNICHARSET *unicharset, std::vector<int> *unichar_ids,
232     std::vector<float> *certs, std::vector<float> *ratings,
233     std::vector<int> *xcoords) const {
234   std::vector<const RecodeNode *> best_nodes;
235   ExtractBestPaths(&best_nodes, nullptr);
236   ExtractPathAsUnicharIds(best_nodes, unichar_ids, certs, ratings, xcoords);
237   if (debug) {
238     DebugPath(unicharset, best_nodes);
239     DebugUnicharPath(unicharset, best_nodes, *unichar_ids, *certs, *ratings,
240                      *xcoords);
241   }
242 }
243 
// Returns the best path as a set of WERD_RES.
// Extracts the best (and second-best) paths, converts the best one to
// unichar-ids and recomputes character_boundaries_ from it, then splits the
// ids into words. Each word gets one single-choice BLOB_CHOICE_LIST per
// character (from that character's rating/certainty) and fake blobs built by
// InitializeWord. words is truncated first. lstm_choice_mode is not read here.
void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX &line_box,
                                              float scale_factor, bool debug,
                                              const UNICHARSET *unicharset,
                                              PointerVector<WERD_RES> *words,
                                              int lstm_choice_mode) {
  words->truncate(0);
  std::vector<int> unichar_ids;
  std::vector<float> certs;
  std::vector<float> ratings;
  std::vector<int> xcoords;
  std::vector<const RecodeNode *> best_nodes;
  std::vector<const RecodeNode *> second_nodes;
  character_boundaries_.clear();
  ExtractBestPaths(&best_nodes, &second_nodes);
  if (debug) {
    // The output vectors are temporarily filled from the second-best path
    // for printing; they are overwritten by the best path below.
    DebugPath(unicharset, best_nodes);
    ExtractPathAsUnicharIds(second_nodes, &unichar_ids, &certs, &ratings,
                            &xcoords);
    tprintf("\nSecond choice path:\n");
    DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings,
                     xcoords);
  }
  // If lstm choice mode is required in granularity level 2, it stores the x
  // Coordinates of every chosen character, to match the alternative choices to
  // it.
  ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, &xcoords,
                          &character_boundaries_);
  int num_ids = unichar_ids.size();
  if (debug) {
    DebugUnicharPath(unicharset, best_nodes, unichar_ids, certs, ratings,
                     xcoords);
  }
  // Convert labels to unichar-ids.
  int word_end = 0;
  float prev_space_cert = 0.0f;
  for (int word_start = 0; word_start < num_ids; word_start = word_end) {
    for (word_end = word_start + 1; word_end < num_ids; ++word_end) {
      // A word is terminated when a space character or start_of_word flag is
      // hit. We also want to force a separate word for every non
      // space-delimited character when not in a dictionary context.
      if (unichar_ids[word_end] == UNICHAR_SPACE) {
        break;
      }
      // xcoords maps a unichar index to its time-step in best_nodes.
      int index = xcoords[word_end];
      if (best_nodes[index]->start_of_word) {
        break;
      }
      if (best_nodes[index]->permuter == TOP_CHOICE_PERM &&
          (!unicharset->IsSpaceDelimited(unichar_ids[word_end]) ||
           !unicharset->IsSpaceDelimited(unichar_ids[word_end - 1]))) {
        break;
      }
    }
    // The word's space certainty is the lower of the spaces on either side.
    float space_cert = 0.0f;
    if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE) {
      space_cert = certs[word_end];
    }
    bool leading_space =
        word_start > 0 && unichar_ids[word_start - 1] == UNICHAR_SPACE;
    // Create a WERD_RES for the output word.
    WERD_RES *word_res =
        InitializeWord(leading_space, line_box, word_start, word_end,
                       std::min(space_cert, prev_space_cert), unicharset,
                       xcoords, scale_factor);
    // Put a single BLOB_CHOICE for each character on the ratings diagonal.
    for (int i = word_start; i < word_end; ++i) {
      auto *choices = new BLOB_CHOICE_LIST;
      BLOB_CHOICE_IT bc_it(choices);
      auto *choice = new BLOB_CHOICE(unichar_ids[i], ratings[i], certs[i], -1,
                                     1.0f, static_cast<float>(INT16_MAX), 0.0f,
                                     BCC_STATIC_CLASSIFIER);
      int col = i - word_start;
      choice->set_matrix_cell(col, col);
      bc_it.add_after_then_move(choice);
      word_res->ratings->put(col, col, choices);
    }
    // The permuter of the word's last character decides the word permuter.
    int index = xcoords[word_end - 1];
    word_res->FakeWordFromRatings(best_nodes[index]->permuter);
    words->push_back(word_res);
    prev_space_cert = space_cert;
    // Step over the space so it does not start the next word.
    if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE) {
      ++word_end;
    }
  }
}
329 
330 struct greater_than {
operator ()tesseract::greater_than331   inline bool operator()(const RecodeNode *&node1, const RecodeNode *&node2) {
332     return (node1->score > node2->score);
333   }
334 };
335 
PrintBeam2(bool uids,int num_outputs,const UNICHARSET * charset,bool secondary) const336 void RecodeBeamSearch::PrintBeam2(bool uids, int num_outputs,
337                                   const UNICHARSET *charset,
338                                   bool secondary) const {
339   std::vector<std::vector<const RecodeNode *>> topology;
340   std::unordered_set<const RecodeNode *> visited;
341   const std::vector<RecodeBeam *> &beam = !secondary ? beam_ : secondary_beam_;
342   // create the topology
343   for (int step = beam.size() - 1; step >= 0; --step) {
344     std::vector<const RecodeNode *> layer;
345     topology.push_back(layer);
346   }
347   // fill the topology with depths first
348   for (int step = beam.size() - 1; step >= 0; --step) {
349     std::vector<tesseract::RecodePair> &heaps = beam.at(step)->beams_->heap();
350     for (auto node : heaps) {
351       int backtracker = 0;
352       const RecodeNode *curr = &node.data();
353       while (curr != nullptr && !visited.count(curr)) {
354         visited.insert(curr);
355         topology[step - backtracker].push_back(curr);
356         curr = curr->prev;
357         ++backtracker;
358       }
359     }
360   }
361   int ct = 0;
362   unsigned cb = 1;
363   for (const std::vector<const RecodeNode *> &layer : topology) {
364     if (cb >= character_boundaries_.size()) {
365       break;
366     }
367     if (ct == character_boundaries_[cb]) {
368       tprintf("***\n");
369       ++cb;
370     }
371     for (const RecodeNode *node : layer) {
372       const char *code;
373       int intCode;
374       if (node->unichar_id != INVALID_UNICHAR_ID) {
375         code = charset->id_to_unichar(node->unichar_id);
376         intCode = node->unichar_id;
377       } else if (node->code == null_char_) {
378         intCode = 0;
379         code = " ";
380       } else {
381         intCode = 666;
382         code = "*";
383       }
384       int intPrevCode = 0;
385       const char *prevCode;
386       float prevScore = 0;
387       if (node->prev != nullptr) {
388         prevScore = node->prev->score;
389         if (node->prev->unichar_id != INVALID_UNICHAR_ID) {
390           prevCode = charset->id_to_unichar(node->prev->unichar_id);
391           intPrevCode = node->prev->unichar_id;
392         } else if (node->code == null_char_) {
393           intPrevCode = 0;
394           prevCode = " ";
395         } else {
396           prevCode = "*";
397           intPrevCode = 666;
398         }
399       } else {
400         prevCode = " ";
401       }
402       if (uids) {
403         tprintf("%x(|)%f(>)%x(|)%f\n", intPrevCode, prevScore, intCode,
404                 node->score);
405       } else {
406         tprintf("%s(|)%f(>)%s(|)%f\n", prevCode, prevScore, code, node->score);
407       }
408     }
409     tprintf("-\n");
410     ++ct;
411   }
412   tprintf("***\n");
413 }
414 
extractSymbolChoices(const UNICHARSET * unicharset)415 void RecodeBeamSearch::extractSymbolChoices(const UNICHARSET *unicharset) {
416   if (character_boundaries_.size() < 2) {
417     return;
418   }
419   // For the first iteration the original beam is analyzed. After that a
420   // new beam is calculated based on the results from the original beam.
421   std::vector<RecodeBeam *> &currentBeam =
422       secondary_beam_.empty() ? beam_ : secondary_beam_;
423   character_boundaries_[0] = 0;
424   for (unsigned j = 1; j < character_boundaries_.size(); ++j) {
425     std::vector<int> unichar_ids;
426     std::vector<float> certs;
427     std::vector<float> ratings;
428     std::vector<int> xcoords;
429     int backpath = character_boundaries_[j] - character_boundaries_[j - 1];
430     std::vector<tesseract::RecodePair> &heaps =
431         currentBeam.at(character_boundaries_[j] - 1)->beams_->heap();
432     std::vector<const RecodeNode *> best_nodes;
433     std::vector<const RecodeNode *> best;
434     // Scan the segmented node chain for valid unichar ids.
435     for (auto entry : heaps) {
436       bool validChar = false;
437       int backcounter = 0;
438       const RecodeNode *node = &entry.data();
439       while (node != nullptr && backcounter < backpath) {
440         if (node->code != null_char_ &&
441             node->unichar_id != INVALID_UNICHAR_ID) {
442           validChar = true;
443           break;
444         }
445         node = node->prev;
446         ++backcounter;
447       }
448       if (validChar) {
449         best.push_back(&entry.data());
450       }
451     }
452     // find the best rated segmented node chain and extract the unichar id.
453     if (!best.empty()) {
454       std::sort(best.begin(), best.end(), greater_than());
455       ExtractPath(best[0], &best_nodes, backpath);
456       ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings,
457                               &xcoords);
458     }
459     if (!unichar_ids.empty()) {
460       int bestPos = 0;
461       for (unsigned i = 1; i < unichar_ids.size(); ++i) {
462         if (ratings[i] < ratings[bestPos]) {
463           bestPos = i;
464         }
465       }
466 #if 0 // TODO: bestCode is currently unused (see commit 2dd5d0d60).
467       int bestCode = -10;
468       for (auto &node : best_nodes) {
469         if (node->unichar_id == unichar_ids[bestPos]) {
470           bestCode = node->code;
471         }
472       }
473 #endif
474       // Exclude the best choice for the followup decoding.
475       std::unordered_set<int> excludeCodeList;
476       for (auto &best_node : best_nodes) {
477         if (best_node->code != null_char_) {
478           excludeCodeList.insert(best_node->code);
479         }
480       }
481       if (j - 1 < excludedUnichars.size()) {
482         for (auto elem : excludeCodeList) {
483           excludedUnichars[j - 1].insert(elem);
484         }
485       } else {
486         excludedUnichars.push_back(excludeCodeList);
487       }
488       // Save the best choice for the choice iterator.
489       if (j - 1 < ctc_choices.size()) {
490         int id = unichar_ids[bestPos];
491         const char *result = unicharset->id_to_unichar_ext(id);
492         float rating = ratings[bestPos];
493         ctc_choices[j - 1].push_back(
494             std::pair<const char *, float>(result, rating));
495       } else {
496         std::vector<std::pair<const char *, float>> choice;
497         int id = unichar_ids[bestPos];
498         const char *result = unicharset->id_to_unichar_ext(id);
499         float rating = ratings[bestPos];
500         choice.emplace_back(result, rating);
501         ctc_choices.push_back(choice);
502       }
503       // fill the blank spot with an empty array
504     } else {
505       if (j - 1 >= excludedUnichars.size()) {
506         std::unordered_set<int> excludeCodeList;
507         excludedUnichars.push_back(excludeCodeList);
508       }
509       if (j - 1 >= ctc_choices.size()) {
510         std::vector<std::pair<const char *, float>> choice;
511         ctc_choices.push_back(choice);
512       }
513     }
514   }
515   for (auto data : secondary_beam_) {
516     delete data;
517   }
518   secondary_beam_.clear();
519 }
520 
521 // Generates debug output of the content of the beams after a Decode.
DebugBeams(const UNICHARSET & unicharset) const522 void RecodeBeamSearch::DebugBeams(const UNICHARSET &unicharset) const {
523   for (int p = 0; p < beam_size_; ++p) {
524     for (int d = 0; d < 2; ++d) {
525       for (int c = 0; c < NC_COUNT; ++c) {
526         auto cont = static_cast<NodeContinuation>(c);
527         int index = BeamIndex(d, cont, 0);
528         if (beam_[p]->beams_[index].empty()) {
529           continue;
530         }
531         // Print all the best scoring nodes for each unichar found.
532         tprintf("Position %d: %s+%s beam\n", p, d ? "Dict" : "Non-Dict",
533                 kNodeContNames[c]);
534         DebugBeamPos(unicharset, beam_[p]->beams_[index]);
535       }
536     }
537   }
538 }
539 
540 // Generates debug output of the content of a single beam position.
DebugBeamPos(const UNICHARSET & unicharset,const RecodeHeap & heap) const541 void RecodeBeamSearch::DebugBeamPos(const UNICHARSET &unicharset,
542                                     const RecodeHeap &heap) const {
543   std::vector<const RecodeNode *> unichar_bests(unicharset.size());
544   const RecodeNode *null_best = nullptr;
545   int heap_size = heap.size();
546   for (int i = 0; i < heap_size; ++i) {
547     const RecodeNode *node = &heap.get(i).data();
548     if (node->unichar_id == INVALID_UNICHAR_ID) {
549       if (null_best == nullptr || null_best->score < node->score) {
550         null_best = node;
551       }
552     } else {
553       if (unichar_bests[node->unichar_id] == nullptr ||
554           unichar_bests[node->unichar_id]->score < node->score) {
555         unichar_bests[node->unichar_id] = node;
556       }
557     }
558   }
559   for (auto &unichar_best : unichar_bests) {
560     if (unichar_best != nullptr) {
561       const RecodeNode &node = *unichar_best;
562       node.Print(null_char_, unicharset, 1);
563     }
564   }
565   if (null_best != nullptr) {
566     null_best->Print(null_char_, unicharset, 1);
567   }
568 }
569 
// Returns the given best_nodes as unichar-ids/certs/ratings/xcoords skipping
// duplicates, nulls and intermediate parts.
// For each emitted unichar:
//  - certs holds the minimum certainty over its leading non-unichar nodes
//    and its duplicates,
//  - ratings holds the negated sum of those certainties, and
//  - xcoords holds its first time-step.
// xcoords gets an extra final entry equal to the path width. If
// character_boundaries is non-null, character boundaries computed from the
// start/end steps of each unichar are appended to it.
/* static */
void RecodeBeamSearch::ExtractPathAsUnicharIds(
    const std::vector<const RecodeNode *> &best_nodes,
    std::vector<int> *unichar_ids, std::vector<float> *certs,
    std::vector<float> *ratings, std::vector<int> *xcoords,
    std::vector<int> *character_boundaries) {
  unichar_ids->clear();
  certs->clear();
  ratings->clear();
  xcoords->clear();
  // starts/ends record, per unichar, its first step and one past its last
  // duplicate; they feed calculateCharBoundaries.
  std::vector<int> starts;
  std::vector<int> ends;
  // Backtrack extracting only valid, non-duplicate unichar-ids.
  int t = 0;
  int width = best_nodes.size();
  while (t < width) {
    double certainty = 0.0;
    double rating = 0.0;
    // Accumulate the nodes without a unichar-id that precede the next unichar.
    while (t < width && best_nodes[t]->unichar_id == INVALID_UNICHAR_ID) {
      double cert = best_nodes[t++]->certainty;
      if (cert < certainty) {
        certainty = cert;
      }
      rating -= cert;
    }
    starts.push_back(t);
    if (t < width) {
      int unichar_id = best_nodes[t]->unichar_id;
      if (unichar_id == UNICHAR_SPACE && !certs->empty() &&
          best_nodes[t]->permuter != NO_PERM) {
        // All the rating and certainty go on the previous character except
        // for the space itself.
        if (certainty < certs->back()) {
          certs->back() = certainty;
        }
        ratings->back() += rating;
        certainty = 0.0;
        rating = 0.0;
      }
      unichar_ids->push_back(unichar_id);
      xcoords->push_back(t);
      // Consume the unichar node and any following duplicates of it.
      do {
        double cert = best_nodes[t++]->certainty;
        // Special-case NO-PERM space to forget the certainty of the previous
        // nulls. See long comment in ContinueContext.
        if (cert < certainty || (unichar_id == UNICHAR_SPACE &&
                                 best_nodes[t - 1]->permuter == NO_PERM)) {
          certainty = cert;
        }
        rating -= cert;
      } while (t < width && best_nodes[t]->duplicate);
      ends.push_back(t);
      certs->push_back(certainty);
      ratings->push_back(rating);
    } else if (!certs->empty()) {
      // Trailing nodes without a unichar are charged to the last character.
      if (certainty < certs->back()) {
        certs->back() = certainty;
      }
      ratings->back() += rating;
    }
  }
  starts.push_back(width);
  if (character_boundaries != nullptr) {
    calculateCharBoundaries(&starts, &ends, character_boundaries, width);
  }
  xcoords->push_back(width);
}
639 
640 // Sets up a word with the ratings matrix and fake blobs with boxes in the
641 // right places.
InitializeWord(bool leading_space,const TBOX & line_box,int word_start,int word_end,float space_certainty,const UNICHARSET * unicharset,const std::vector<int> & xcoords,float scale_factor)642 WERD_RES *RecodeBeamSearch::InitializeWord(bool leading_space,
643                                            const TBOX &line_box, int word_start,
644                                            int word_end, float space_certainty,
645                                            const UNICHARSET *unicharset,
646                                            const std::vector<int> &xcoords,
647                                            float scale_factor) {
648   // Make a fake blob for each non-zero label.
649   C_BLOB_LIST blobs;
650   C_BLOB_IT b_it(&blobs);
651   for (int i = word_start; i < word_end; ++i) {
652     if (static_cast<unsigned>(i + 1) < character_boundaries_.size()) {
653       TBOX box(static_cast<int16_t>(
654                    std::floor(character_boundaries_[i] * scale_factor)) +
655                    line_box.left(),
656                line_box.bottom(),
657                static_cast<int16_t>(
658                    std::ceil(character_boundaries_[i + 1] * scale_factor)) +
659                    line_box.left(),
660                line_box.top());
661       b_it.add_after_then_move(C_BLOB::FakeBlob(box));
662     }
663   }
664   // Make a fake word from the blobs.
665   WERD *word = new WERD(&blobs, leading_space, nullptr);
666   // Make a WERD_RES from the word.
667   auto *word_res = new WERD_RES(word);
668   word_res->end = word_end - word_start + leading_space;
669   word_res->uch_set = unicharset;
670   word_res->combination = true; // Give it ownership of the word.
671   word_res->space_certainty = space_certainty;
672   word_res->ratings = new MATRIX(word_end - word_start, 1);
673   return word_res;
674 }
675 
676 // Fills top_n_flags_ with bools that are true iff the corresponding output
677 // is one of the top_n.
ComputeTopN(const float * outputs,int num_outputs,int top_n)678 void RecodeBeamSearch::ComputeTopN(const float *outputs, int num_outputs,
679                                    int top_n) {
680   top_n_flags_.clear();
681   top_n_flags_.resize(num_outputs, TN_ALSO_RAN);
682   top_code_ = -1;
683   second_code_ = -1;
684   top_heap_.clear();
685   for (int i = 0; i < num_outputs; ++i) {
686     if (top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key()) {
687       TopPair entry(outputs[i], i);
688       top_heap_.Push(&entry);
689       if (top_heap_.size() > top_n) {
690         top_heap_.Pop(&entry);
691       }
692     }
693   }
694   while (!top_heap_.empty()) {
695     TopPair entry;
696     top_heap_.Pop(&entry);
697     if (top_heap_.size() > 1) {
698       top_n_flags_[entry.data()] = TN_TOPN;
699     } else {
700       top_n_flags_[entry.data()] = TN_TOP2;
701       if (top_heap_.empty()) {
702         top_code_ = entry.data();
703       } else {
704         second_code_ = entry.data();
705       }
706     }
707   }
708   top_n_flags_[null_char_] = TN_TOP2;
709 }
710 
ComputeSecTopN(std::unordered_set<int> * exList,const float * outputs,int num_outputs,int top_n)711 void RecodeBeamSearch::ComputeSecTopN(std::unordered_set<int> *exList,
712                                       const float *outputs, int num_outputs,
713                                       int top_n) {
714   top_n_flags_.clear();
715   top_n_flags_.resize(num_outputs, TN_ALSO_RAN);
716   top_code_ = -1;
717   second_code_ = -1;
718   top_heap_.clear();
719   for (int i = 0; i < num_outputs; ++i) {
720     if ((top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key()) &&
721         !exList->count(i)) {
722       TopPair entry(outputs[i], i);
723       top_heap_.Push(&entry);
724       if (top_heap_.size() > top_n) {
725         top_heap_.Pop(&entry);
726       }
727     }
728   }
729   while (!top_heap_.empty()) {
730     TopPair entry;
731     top_heap_.Pop(&entry);
732     if (top_heap_.size() > 1) {
733       top_n_flags_[entry.data()] = TN_TOPN;
734     } else {
735       top_n_flags_[entry.data()] = TN_TOP2;
736       if (top_heap_.empty()) {
737         top_code_ = entry.data();
738       } else {
739         second_code_ = entry.data();
740       }
741     }
742   }
743   top_n_flags_[null_char_] = TN_TOP2;
744 }
745 
746 // Adds the computation for the current time-step to the beam. Call at each
747 // time-step in sequence from left to right. outputs is the activation vector
748 // for the current timestep.
DecodeStep(const float * outputs,int t,double dict_ratio,double cert_offset,double worst_dict_cert,const UNICHARSET * charset,bool debug)749 void RecodeBeamSearch::DecodeStep(const float *outputs, int t,
750                                   double dict_ratio, double cert_offset,
751                                   double worst_dict_cert,
752                                   const UNICHARSET *charset, bool debug) {
753   if (t == static_cast<int>(beam_.size())) {
754     beam_.push_back(new RecodeBeam);
755   }
756   RecodeBeam *step = beam_[t];
757   beam_size_ = t + 1;
758   step->Clear();
759   if (t == 0) {
760     // The first step can only use singles and initials.
761     ContinueContext(nullptr, BeamIndex(false, NC_ANYTHING, 0), outputs, TN_TOP2,
762                     charset, dict_ratio, cert_offset, worst_dict_cert, step);
763     if (dict_ != nullptr) {
764       ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs,
765                       TN_TOP2, charset, dict_ratio, cert_offset,
766                       worst_dict_cert, step);
767     }
768   } else {
769     RecodeBeam *prev = beam_[t - 1];
770     if (debug) {
771       int beam_index = BeamIndex(true, NC_ANYTHING, 0);
772       for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
773         std::vector<const RecodeNode *> path;
774         ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
775         tprintf("Step %d: Dawg beam %d:\n", t, i);
776         DebugPath(charset, path);
777       }
778       beam_index = BeamIndex(false, NC_ANYTHING, 0);
779       for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
780         std::vector<const RecodeNode *> path;
781         ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
782         tprintf("Step %d: Non-Dawg beam %d:\n", t, i);
783         DebugPath(charset, path);
784       }
785     }
786     int total_beam = 0;
787     // Work through the scores by group (top-2, top-n, the rest) while the beam
788     // is empty. This enables extending the context using only the top-n results
789     // first, which may have an empty intersection with the valid codes, so we
790     // fall back to the rest if the beam is empty.
791     for (int tn = 0; tn < TN_COUNT && total_beam == 0; ++tn) {
792       auto top_n = static_cast<TopNState>(tn);
793       for (int index = 0; index < kNumBeams; ++index) {
794         // Working backwards through the heaps doesn't guarantee that we see the
795         // best first, but it comes before a lot of the worst, so it is slightly
796         // more efficient than going forwards.
797         for (int i = prev->beams_[index].size() - 1; i >= 0; --i) {
798           ContinueContext(&prev->beams_[index].get(i).data(), index, outputs,
799                           top_n, charset, dict_ratio, cert_offset,
800                           worst_dict_cert, step);
801         }
802       }
803       for (int index = 0; index < kNumBeams; ++index) {
804         if (ContinuationFromBeamsIndex(index) == NC_ANYTHING) {
805           total_beam += step->beams_[index].size();
806         }
807       }
808     }
809     // Special case for the best initial dawg. Push it on the heap if good
810     // enough, but there is only one, so it doesn't blow up the beam.
811     for (int c = 0; c < NC_COUNT; ++c) {
812       if (step->best_initial_dawgs_[c].code >= 0) {
813         int index = BeamIndex(true, static_cast<NodeContinuation>(c), 0);
814         RecodeHeap *dawg_heap = &step->beams_[index];
815         PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c],
816                          dawg_heap);
817       }
818     }
819   }
820 }
821 
DecodeSecondaryStep(const float * outputs,int t,double dict_ratio,double cert_offset,double worst_dict_cert,const UNICHARSET * charset,bool debug)822 void RecodeBeamSearch::DecodeSecondaryStep(
823     const float *outputs, int t, double dict_ratio, double cert_offset,
824     double worst_dict_cert, const UNICHARSET *charset, bool debug) {
825   if (t == static_cast<int>(secondary_beam_.size())) {
826     secondary_beam_.push_back(new RecodeBeam);
827   }
828   RecodeBeam *step = secondary_beam_[t];
829   step->Clear();
830   if (t == 0) {
831     // The first step can only use singles and initials.
832     ContinueContext(nullptr, BeamIndex(false, NC_ANYTHING, 0), outputs, TN_TOP2,
833                     charset, dict_ratio, cert_offset, worst_dict_cert, step);
834     if (dict_ != nullptr) {
835       ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs,
836                       TN_TOP2, charset, dict_ratio, cert_offset,
837                       worst_dict_cert, step);
838     }
839   } else {
840     RecodeBeam *prev = secondary_beam_[t - 1];
841     if (debug) {
842       int beam_index = BeamIndex(true, NC_ANYTHING, 0);
843       for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
844         std::vector<const RecodeNode *> path;
845         ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
846         tprintf("Step %d: Dawg beam %d:\n", t, i);
847         DebugPath(charset, path);
848       }
849       beam_index = BeamIndex(false, NC_ANYTHING, 0);
850       for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
851         std::vector<const RecodeNode *> path;
852         ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
853         tprintf("Step %d: Non-Dawg beam %d:\n", t, i);
854         DebugPath(charset, path);
855       }
856     }
857     int total_beam = 0;
858     // Work through the scores by group (top-2, top-n, the rest) while the beam
859     // is empty. This enables extending the context using only the top-n results
860     // first, which may have an empty intersection with the valid codes, so we
861     // fall back to the rest if the beam is empty.
862     for (int tn = 0; tn < TN_COUNT && total_beam == 0; ++tn) {
863       auto top_n = static_cast<TopNState>(tn);
864       for (int index = 0; index < kNumBeams; ++index) {
865         // Working backwards through the heaps doesn't guarantee that we see the
866         // best first, but it comes before a lot of the worst, so it is slightly
867         // more efficient than going forwards.
868         for (int i = prev->beams_[index].size() - 1; i >= 0; --i) {
869           ContinueContext(&prev->beams_[index].get(i).data(), index, outputs,
870                           top_n, charset, dict_ratio, cert_offset,
871                           worst_dict_cert, step);
872         }
873       }
874       for (int index = 0; index < kNumBeams; ++index) {
875         if (ContinuationFromBeamsIndex(index) == NC_ANYTHING) {
876           total_beam += step->beams_[index].size();
877         }
878       }
879     }
880     // Special case for the best initial dawg. Push it on the heap if good
881     // enough, but there is only one, so it doesn't blow up the beam.
882     for (int c = 0; c < NC_COUNT; ++c) {
883       if (step->best_initial_dawgs_[c].code >= 0) {
884         int index = BeamIndex(true, static_cast<NodeContinuation>(c), 0);
885         RecodeHeap *dawg_heap = &step->beams_[index];
886         PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c],
887                          dawg_heap);
888       }
889     }
890   }
891 }
892 
893 // Adds to the appropriate beams the legal (according to recoder)
894 // continuations of context prev, which is of the given length, using the
895 // given network outputs to provide scores to the choices. Uses only those
896 // choices for which top_n_flags[index] == top_n_flag.
ContinueContext(const RecodeNode * prev,int index,const float * outputs,TopNState top_n_flag,const UNICHARSET * charset,double dict_ratio,double cert_offset,double worst_dict_cert,RecodeBeam * step)897 void RecodeBeamSearch::ContinueContext(
898     const RecodeNode *prev, int index, const float *outputs,
899     TopNState top_n_flag, const UNICHARSET *charset, double dict_ratio,
900     double cert_offset, double worst_dict_cert, RecodeBeam *step) {
901   RecodedCharID prefix;
902   RecodedCharID full_code;
903   const RecodeNode *previous = prev;
904   int length = LengthFromBeamsIndex(index);
905   bool use_dawgs = IsDawgFromBeamsIndex(index);
906   NodeContinuation prev_cont = ContinuationFromBeamsIndex(index);
907   for (int p = length - 1; p >= 0; --p, previous = previous->prev) {
908     while (previous != nullptr &&
909            (previous->duplicate || previous->code == null_char_)) {
910       previous = previous->prev;
911     }
912     if (previous != nullptr) {
913       prefix.Set(p, previous->code);
914       full_code.Set(p, previous->code);
915     }
916   }
917   if (prev != nullptr && !is_simple_text_) {
918     if (top_n_flags_[prev->code] == top_n_flag) {
919       if (prev_cont != NC_NO_DUP) {
920         float cert =
921             NetworkIO::ProbToCertainty(outputs[prev->code]) + cert_offset;
922         PushDupOrNoDawgIfBetter(length, true, prev->code, prev->unichar_id,
923                                 cert, worst_dict_cert, dict_ratio, use_dawgs,
924                                 NC_ANYTHING, prev, step);
925       }
926       if (prev_cont == NC_ANYTHING && top_n_flag == TN_TOP2 &&
927           prev->code != null_char_) {
928         float cert = NetworkIO::ProbToCertainty(outputs[prev->code] +
929                                                 outputs[null_char_]) +
930                      cert_offset;
931         PushDupOrNoDawgIfBetter(length, true, prev->code, prev->unichar_id,
932                                 cert, worst_dict_cert, dict_ratio, use_dawgs,
933                                 NC_NO_DUP, prev, step);
934       }
935     }
936     if (prev_cont == NC_ONLY_DUP) {
937       return;
938     }
939     if (prev->code != null_char_ && length > 0 &&
940         top_n_flags_[null_char_] == top_n_flag) {
941       // Allow nulls within multi code sequences, as the nulls within are not
942       // explicitly included in the code sequence.
943       float cert =
944           NetworkIO::ProbToCertainty(outputs[null_char_]) + cert_offset;
945       PushDupOrNoDawgIfBetter(length, false, null_char_, INVALID_UNICHAR_ID,
946                               cert, worst_dict_cert, dict_ratio, use_dawgs,
947                               NC_ANYTHING, prev, step);
948     }
949   }
950   const std::vector<int> *final_codes = recoder_.GetFinalCodes(prefix);
951   if (final_codes != nullptr) {
952     for (int code : *final_codes) {
953       if (top_n_flags_[code] != top_n_flag) {
954         continue;
955       }
956       if (prev != nullptr && prev->code == code && !is_simple_text_) {
957         continue;
958       }
959       float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset;
960       if (cert < kMinCertainty && code != null_char_) {
961         continue;
962       }
963       full_code.Set(length, code);
964       int unichar_id = recoder_.DecodeUnichar(full_code);
965       // Map the null char to INVALID.
966       if (length == 0 && code == null_char_) {
967         unichar_id = INVALID_UNICHAR_ID;
968       }
969       if (unichar_id != INVALID_UNICHAR_ID && charset != nullptr &&
970           !charset->get_enabled(unichar_id)) {
971         continue; // disabled by whitelist/blacklist
972       }
973       ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
974                       use_dawgs, NC_ANYTHING, prev, step);
975       if (top_n_flag == TN_TOP2 && code != null_char_) {
976         float prob = outputs[code] + outputs[null_char_];
977         if (prev != nullptr && prev_cont == NC_ANYTHING &&
978             prev->code != null_char_ &&
979             ((prev->code == top_code_ && code == second_code_) ||
980              (code == top_code_ && prev->code == second_code_))) {
981           prob += outputs[prev->code];
982         }
983         cert = NetworkIO::ProbToCertainty(prob) + cert_offset;
984         ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
985                         use_dawgs, NC_ONLY_DUP, prev, step);
986       }
987     }
988   }
989   const std::vector<int> *next_codes = recoder_.GetNextCodes(prefix);
990   if (next_codes != nullptr) {
991     for (int code : *next_codes) {
992       if (top_n_flags_[code] != top_n_flag) {
993         continue;
994       }
995       if (prev != nullptr && prev->code == code && !is_simple_text_) {
996         continue;
997       }
998       float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset;
999       PushDupOrNoDawgIfBetter(length + 1, false, code, INVALID_UNICHAR_ID, cert,
1000                               worst_dict_cert, dict_ratio, use_dawgs,
1001                               NC_ANYTHING, prev, step);
1002       if (top_n_flag == TN_TOP2 && code != null_char_) {
1003         float prob = outputs[code] + outputs[null_char_];
1004         if (prev != nullptr && prev_cont == NC_ANYTHING &&
1005             prev->code != null_char_ &&
1006             ((prev->code == top_code_ && code == second_code_) ||
1007              (code == top_code_ && prev->code == second_code_))) {
1008           prob += outputs[prev->code];
1009         }
1010         cert = NetworkIO::ProbToCertainty(prob) + cert_offset;
1011         PushDupOrNoDawgIfBetter(length + 1, false, code, INVALID_UNICHAR_ID,
1012                                 cert, worst_dict_cert, dict_ratio, use_dawgs,
1013                                 NC_ONLY_DUP, prev, step);
1014       }
1015     }
1016   }
1017 }
1018 
1019 // Continues for a new unichar, using dawg or non-dawg as per flag.
ContinueUnichar(int code,int unichar_id,float cert,float worst_dict_cert,float dict_ratio,bool use_dawgs,NodeContinuation cont,const RecodeNode * prev,RecodeBeam * step)1020 void RecodeBeamSearch::ContinueUnichar(int code, int unichar_id, float cert,
1021                                        float worst_dict_cert, float dict_ratio,
1022                                        bool use_dawgs, NodeContinuation cont,
1023                                        const RecodeNode *prev,
1024                                        RecodeBeam *step) {
1025   if (use_dawgs) {
1026     if (cert > worst_dict_cert) {
1027       ContinueDawg(code, unichar_id, cert, cont, prev, step);
1028     }
1029   } else {
1030     RecodeHeap *nodawg_heap = &step->beams_[BeamIndex(false, cont, 0)];
1031     PushHeapIfBetter(kBeamWidths[0], code, unichar_id, TOP_CHOICE_PERM, false,
1032                      false, false, false, cert * dict_ratio, prev, nullptr,
1033                      nodawg_heap);
1034     if (dict_ != nullptr &&
1035         ((unichar_id == UNICHAR_SPACE && cert > worst_dict_cert) ||
1036          !dict_->getUnicharset().IsSpaceDelimited(unichar_id))) {
1037       // Any top choice position that can start a new word, ie a space or
1038       // any non-space-delimited character, should also be considered
1039       // by the dawg search, so push initial dawg to the dawg heap.
1040       float dawg_cert = cert;
1041       PermuterType permuter = TOP_CHOICE_PERM;
1042       // Since we use the space either side of a dictionary word in the
1043       // certainty of the word, (to properly handle weak spaces) and the
1044       // space is coming from a non-dict word, we need special conditions
1045       // to avoid degrading the certainty of the dict word that follows.
1046       // With a space we don't multiply the certainty by dict_ratio, and we
1047       // flag the space with NO_PERM to indicate that we should not use the
1048       // predecessor nulls to generate the confidence for the space, as they
1049       // have already been multiplied by dict_ratio, and we can't go back to
1050       // insert more entries in any previous heaps.
1051       if (unichar_id == UNICHAR_SPACE) {
1052         permuter = NO_PERM;
1053       } else {
1054         dawg_cert *= dict_ratio;
1055       }
1056       PushInitialDawgIfBetter(code, unichar_id, permuter, false, false,
1057                               dawg_cert, cont, prev, step);
1058     }
1059   }
1060 }
1061 
1062 // Adds a RecodeNode composed of the tuple (code, unichar_id, cert, prev,
1063 // appropriate-dawg-args, cert) to the given heap (dawg_beam_) if unichar_id
1064 // is a valid continuation of whatever is in prev.
ContinueDawg(int code,int unichar_id,float cert,NodeContinuation cont,const RecodeNode * prev,RecodeBeam * step)1065 void RecodeBeamSearch::ContinueDawg(int code, int unichar_id, float cert,
1066                                     NodeContinuation cont,
1067                                     const RecodeNode *prev, RecodeBeam *step) {
1068   RecodeHeap *dawg_heap = &step->beams_[BeamIndex(true, cont, 0)];
1069   RecodeHeap *nodawg_heap = &step->beams_[BeamIndex(false, cont, 0)];
1070   if (unichar_id == INVALID_UNICHAR_ID) {
1071     PushHeapIfBetter(kBeamWidths[0], code, unichar_id, NO_PERM, false, false,
1072                      false, false, cert, prev, nullptr, dawg_heap);
1073     return;
1074   }
1075   // Avoid dictionary probe if score a total loss.
1076   float score = cert;
1077   if (prev != nullptr) {
1078     score += prev->score;
1079   }
1080   if (dawg_heap->size() >= kBeamWidths[0] &&
1081       score <= dawg_heap->PeekTop().data().score &&
1082       nodawg_heap->size() >= kBeamWidths[0] &&
1083       score <= nodawg_heap->PeekTop().data().score) {
1084     return;
1085   }
1086   const RecodeNode *uni_prev = prev;
1087   // Prev may be a partial code, null_char, or duplicate, so scan back to the
1088   // last valid unichar_id.
1089   while (uni_prev != nullptr &&
1090          (uni_prev->unichar_id == INVALID_UNICHAR_ID || uni_prev->duplicate)) {
1091     uni_prev = uni_prev->prev;
1092   }
1093   if (unichar_id == UNICHAR_SPACE) {
1094     if (uni_prev != nullptr && uni_prev->end_of_word) {
1095       // Space is good. Push initial state, to the dawg beam and a regular
1096       // space to the top choice beam.
1097       PushInitialDawgIfBetter(code, unichar_id, uni_prev->permuter, false,
1098                               false, cert, cont, prev, step);
1099       PushHeapIfBetter(kBeamWidths[0], code, unichar_id, uni_prev->permuter,
1100                        false, false, false, false, cert, prev, nullptr,
1101                        nodawg_heap);
1102     }
1103     return;
1104   } else if (uni_prev != nullptr && uni_prev->start_of_dawg &&
1105              uni_prev->unichar_id != UNICHAR_SPACE &&
1106              dict_->getUnicharset().IsSpaceDelimited(uni_prev->unichar_id) &&
1107              dict_->getUnicharset().IsSpaceDelimited(unichar_id)) {
1108     return; // Can't break words between space delimited chars.
1109   }
1110   DawgPositionVector initial_dawgs;
1111   auto *updated_dawgs = new DawgPositionVector;
1112   DawgArgs dawg_args(&initial_dawgs, updated_dawgs, NO_PERM);
1113   bool word_start = false;
1114   if (uni_prev == nullptr) {
1115     // Starting from beginning of line.
1116     dict_->default_dawgs(&initial_dawgs, false);
1117     word_start = true;
1118   } else if (uni_prev->dawgs != nullptr) {
1119     // Continuing a previous dict word.
1120     dawg_args.active_dawgs = uni_prev->dawgs;
1121     word_start = uni_prev->start_of_dawg;
1122   } else {
1123     return; // Can't continue if not a dict word.
1124   }
1125   auto permuter = static_cast<PermuterType>(dict_->def_letter_is_okay(
1126       &dawg_args, dict_->getUnicharset(), unichar_id, false));
1127   if (permuter != NO_PERM) {
1128     PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter, false,
1129                      word_start, dawg_args.valid_end, false, cert, prev,
1130                      dawg_args.updated_dawgs, dawg_heap);
1131     if (dawg_args.valid_end && !space_delimited_) {
1132       // We can start another word right away, so push initial state as well,
1133       // to the dawg beam, and the regular character to the top choice beam,
1134       // since non-dict words can start here too.
1135       PushInitialDawgIfBetter(code, unichar_id, permuter, word_start, true,
1136                               cert, cont, prev, step);
1137       PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter, false,
1138                        word_start, true, false, cert, prev, nullptr,
1139                        nodawg_heap);
1140     }
1141   } else {
1142     delete updated_dawgs;
1143   }
1144 }
1145 
1146 // Adds a RecodeNode composed of the tuple (code, unichar_id,
1147 // initial-dawg-state, prev, cert) to the given heap if/ there is room or if
1148 // better than the current worst element if already full.
PushInitialDawgIfBetter(int code,int unichar_id,PermuterType permuter,bool start,bool end,float cert,NodeContinuation cont,const RecodeNode * prev,RecodeBeam * step)1149 void RecodeBeamSearch::PushInitialDawgIfBetter(int code, int unichar_id,
1150                                                PermuterType permuter,
1151                                                bool start, bool end, float cert,
1152                                                NodeContinuation cont,
1153                                                const RecodeNode *prev,
1154                                                RecodeBeam *step) {
1155   RecodeNode *best_initial_dawg = &step->best_initial_dawgs_[cont];
1156   float score = cert;
1157   if (prev != nullptr) {
1158     score += prev->score;
1159   }
1160   if (best_initial_dawg->code < 0 || score > best_initial_dawg->score) {
1161     auto *initial_dawgs = new DawgPositionVector;
1162     dict_->default_dawgs(initial_dawgs, false);
1163     RecodeNode node(code, unichar_id, permuter, true, start, end, false, cert,
1164                     score, prev, initial_dawgs,
1165                     ComputeCodeHash(code, false, prev));
1166     *best_initial_dawg = node;
1167   }
1168 }
1169 
1170 // Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
1171 // false, false, false, false, cert, prev, nullptr) to heap if there is room
1172 // or if better than the current worst element if already full.
1173 /* static */
PushDupOrNoDawgIfBetter(int length,bool dup,int code,int unichar_id,float cert,float worst_dict_cert,float dict_ratio,bool use_dawgs,NodeContinuation cont,const RecodeNode * prev,RecodeBeam * step)1174 void RecodeBeamSearch::PushDupOrNoDawgIfBetter(
1175     int length, bool dup, int code, int unichar_id, float cert,
1176     float worst_dict_cert, float dict_ratio, bool use_dawgs,
1177     NodeContinuation cont, const RecodeNode *prev, RecodeBeam *step) {
1178   int index = BeamIndex(use_dawgs, cont, length);
1179   if (use_dawgs) {
1180     if (cert > worst_dict_cert) {
1181       PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
1182                        prev ? prev->permuter : NO_PERM, false, false, false,
1183                        dup, cert, prev, nullptr, &step->beams_[index]);
1184     }
1185   } else {
1186     cert *= dict_ratio;
1187     if (cert >= kMinCertainty || code == null_char_) {
1188       PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
1189                        prev ? prev->permuter : TOP_CHOICE_PERM, false, false,
1190                        false, dup, cert, prev, nullptr, &step->beams_[index]);
1191     }
1192   }
1193 }
1194 
1195 // Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
1196 // dawg_start, word_start, end, dup, cert, prev, d) to heap if there is room
1197 // or if better than the current worst element if already full.
PushHeapIfBetter(int max_size,int code,int unichar_id,PermuterType permuter,bool dawg_start,bool word_start,bool end,bool dup,float cert,const RecodeNode * prev,DawgPositionVector * d,RecodeHeap * heap)1198 void RecodeBeamSearch::PushHeapIfBetter(int max_size, int code, int unichar_id,
1199                                         PermuterType permuter, bool dawg_start,
1200                                         bool word_start, bool end, bool dup,
1201                                         float cert, const RecodeNode *prev,
1202                                         DawgPositionVector *d,
1203                                         RecodeHeap *heap) {
1204   float score = cert;
1205   if (prev != nullptr) {
1206     score += prev->score;
1207   }
1208   if (heap->size() < max_size || score > heap->PeekTop().data().score) {
1209     uint64_t hash = ComputeCodeHash(code, dup, prev);
1210     RecodeNode node(code, unichar_id, permuter, dawg_start, word_start, end,
1211                     dup, cert, score, prev, d, hash);
1212     if (UpdateHeapIfMatched(&node, heap)) {
1213       return;
1214     }
1215     RecodePair entry(score, node);
1216     heap->Push(&entry);
1217     ASSERT_HOST(entry.data().dawgs == nullptr);
1218     if (heap->size() > max_size) {
1219       heap->Pop(&entry);
1220     }
1221   } else {
1222     delete d;
1223   }
1224 }
1225 
1226 // Adds a RecodeNode to heap if there is room
1227 // or if better than the current worst element if already full.
PushHeapIfBetter(int max_size,RecodeNode * node,RecodeHeap * heap)1228 void RecodeBeamSearch::PushHeapIfBetter(int max_size, RecodeNode *node,
1229                                         RecodeHeap *heap) {
1230   if (heap->size() < max_size || node->score > heap->PeekTop().data().score) {
1231     if (UpdateHeapIfMatched(node, heap)) {
1232       return;
1233     }
1234     RecodePair entry(node->score, *node);
1235     heap->Push(&entry);
1236     ASSERT_HOST(entry.data().dawgs == nullptr);
1237     if (heap->size() > max_size) {
1238       heap->Pop(&entry);
1239     }
1240   }
1241 }
1242 
1243 // Searches the heap for a matching entry, and updates the score with
1244 // reshuffle if needed. Returns true if there was a match.
UpdateHeapIfMatched(RecodeNode * new_node,RecodeHeap * heap)1245 bool RecodeBeamSearch::UpdateHeapIfMatched(RecodeNode *new_node,
1246                                            RecodeHeap *heap) {
1247   // TODO(rays) consider hash map instead of linear search.
1248   // It might not be faster because the hash map would have to be updated
1249   // every time a heap reshuffle happens, and that would be a lot of overhead.
1250   std::vector<RecodePair> &nodes = heap->heap();
1251   for (auto &i : nodes) {
1252     RecodeNode &node = i.data();
1253     if (node.code == new_node->code && node.code_hash == new_node->code_hash &&
1254         node.permuter == new_node->permuter &&
1255         node.start_of_dawg == new_node->start_of_dawg) {
1256       if (new_node->score > node.score) {
1257         // The new one is better. Update the entire node in the heap and
1258         // reshuffle.
1259         node = *new_node;
1260         i.key() = node.score;
1261         heap->Reshuffle(&i);
1262       }
1263       return true;
1264     }
1265   }
1266   return false;
1267 }
1268 
1269 // Computes and returns the code-hash for the given code and prev.
ComputeCodeHash(int code,bool dup,const RecodeNode * prev) const1270 uint64_t RecodeBeamSearch::ComputeCodeHash(int code, bool dup,
1271                                            const RecodeNode *prev) const {
1272   uint64_t hash = prev == nullptr ? 0 : prev->code_hash;
1273   if (!dup && code != null_char_) {
1274     int num_classes = recoder_.code_range();
1275     uint64_t carry = (((hash >> 32) * num_classes) >> 32);
1276     hash *= num_classes;
1277     hash += carry;
1278     hash += code;
1279   }
1280   return hash;
1281 }
1282 
1283 // Backtracks to extract the best path through the lattice that was built
1284 // during Decode. On return the best_nodes vector essentially contains the set
1285 // of code, score pairs that make the optimal path with the constraint that
1286 // the recoder can decode the code sequence back to a sequence of unichar-ids.
ExtractBestPaths(std::vector<const RecodeNode * > * best_nodes,std::vector<const RecodeNode * > * second_nodes) const1287 void RecodeBeamSearch::ExtractBestPaths(
1288     std::vector<const RecodeNode *> *best_nodes,
1289     std::vector<const RecodeNode *> *second_nodes) const {
1290   // Scan both beams to extract the best and second best paths.
1291   const RecodeNode *best_node = nullptr;
1292   const RecodeNode *second_best_node = nullptr;
1293   const RecodeBeam *last_beam = beam_[beam_size_ - 1];
1294   for (int c = 0; c < NC_COUNT; ++c) {
1295     if (c == NC_ONLY_DUP) {
1296       continue;
1297     }
1298     auto cont = static_cast<NodeContinuation>(c);
1299     for (int is_dawg = 0; is_dawg < 2; ++is_dawg) {
1300       int beam_index = BeamIndex(is_dawg, cont, 0);
1301       int heap_size = last_beam->beams_[beam_index].size();
1302       for (int h = 0; h < heap_size; ++h) {
1303         const RecodeNode *node = &last_beam->beams_[beam_index].get(h).data();
1304         if (is_dawg) {
1305           // dawg_node may be a null_char, or duplicate, so scan back to the
1306           // last valid unichar_id.
1307           const RecodeNode *dawg_node = node;
1308           while (dawg_node != nullptr &&
1309                  (dawg_node->unichar_id == INVALID_UNICHAR_ID ||
1310                   dawg_node->duplicate)) {
1311             dawg_node = dawg_node->prev;
1312           }
1313           if (dawg_node == nullptr ||
1314               (!dawg_node->end_of_word &&
1315                dawg_node->unichar_id != UNICHAR_SPACE)) {
1316             // Dawg node is not valid.
1317             continue;
1318           }
1319         }
1320         if (best_node == nullptr || node->score > best_node->score) {
1321           second_best_node = best_node;
1322           best_node = node;
1323         } else if (second_best_node == nullptr ||
1324                    node->score > second_best_node->score) {
1325           second_best_node = node;
1326         }
1327       }
1328     }
1329   }
1330   if (second_nodes != nullptr) {
1331     ExtractPath(second_best_node, second_nodes);
1332   }
1333   ExtractPath(best_node, best_nodes);
1334 }
1335 
1336 // Helper backtracks through the lattice from the given node, storing the
1337 // path and reversing it.
ExtractPath(const RecodeNode * node,std::vector<const RecodeNode * > * path) const1338 void RecodeBeamSearch::ExtractPath(
1339     const RecodeNode *node, std::vector<const RecodeNode *> *path) const {
1340   path->clear();
1341   while (node != nullptr) {
1342     path->push_back(node);
1343     node = node->prev;
1344   }
1345   std::reverse(path->begin(), path->end());
1346 }
1347 
ExtractPath(const RecodeNode * node,std::vector<const RecodeNode * > * path,int limiter) const1348 void RecodeBeamSearch::ExtractPath(const RecodeNode *node,
1349                                    std::vector<const RecodeNode *> *path,
1350                                    int limiter) const {
1351   int pathcounter = 0;
1352   path->clear();
1353   while (node != nullptr && pathcounter < limiter) {
1354     path->push_back(node);
1355     node = node->prev;
1356     ++pathcounter;
1357   }
1358   std::reverse(path->begin(), path->end());
1359 }
1360 
1361 // Helper prints debug information on the given lattice path.
DebugPath(const UNICHARSET * unicharset,const std::vector<const RecodeNode * > & path) const1362 void RecodeBeamSearch::DebugPath(
1363     const UNICHARSET *unicharset,
1364     const std::vector<const RecodeNode *> &path) const {
1365   for (unsigned c = 0; c < path.size(); ++c) {
1366     const RecodeNode &node = *path[c];
1367     tprintf("%u ", c);
1368     node.Print(null_char_, *unicharset, 1);
1369   }
1370 }
1371 
1372 // Helper prints debug information on the given unichar path.
DebugUnicharPath(const UNICHARSET * unicharset,const std::vector<const RecodeNode * > & path,const std::vector<int> & unichar_ids,const std::vector<float> & certs,const std::vector<float> & ratings,const std::vector<int> & xcoords) const1373 void RecodeBeamSearch::DebugUnicharPath(
1374     const UNICHARSET *unicharset, const std::vector<const RecodeNode *> &path,
1375     const std::vector<int> &unichar_ids, const std::vector<float> &certs,
1376     const std::vector<float> &ratings, const std::vector<int> &xcoords) const {
1377   auto num_ids = unichar_ids.size();
1378   double total_rating = 0.0;
1379   for (unsigned c = 0; c < num_ids; ++c) {
1380     int coord = xcoords[c];
1381     tprintf("%d %d=%s r=%g, c=%g, s=%d, e=%d, perm=%d\n", coord, unichar_ids[c],
1382             unicharset->debug_str(unichar_ids[c]).c_str(), ratings[c], certs[c],
1383             path[coord]->start_of_word, path[coord]->end_of_word,
1384             path[coord]->permuter);
1385     total_rating += ratings[c];
1386   }
1387   tprintf("Path total rating = %g\n", total_rating);
1388 }
1389 
1390 } // namespace tesseract.
1391