1 ///////////////////////////////////////////////////////////////////////
2 // File: recodebeam.cpp
3 // Description: Beam search to decode from the re-encoded CJK as a sequence of
4 // smaller numbers in place of a single large code.
5 // Author: Ray Smith
6 //
7 // (C) Copyright 2015, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
17 //
18 ///////////////////////////////////////////////////////////////////////
19
20 #include "recodebeam.h"
21
22 #include "networkio.h"
23 #include "pageres.h"
24 #include "unicharcompress.h"
25
26 #include <algorithm> // for std::reverse
27 #include <deque>
28 #include <map>
29 #include <set>
30 #include <tuple>
31 #include <unordered_set>
32 #include <vector>
33
34 namespace tesseract {
35
36 // The beam width at each code position.
37 const int RecodeBeamSearch::kBeamWidths[RecodedCharID::kMaxCodeLen + 1] = {
38 5, 10, 16, 16, 16, 16, 16, 16, 16, 16,
39 };
40
41 static const char *kNodeContNames[] = {"Anything", "OnlyDup", "NoDup"};
42
43 // Prints debug details of the node.
Print(int null_char,const UNICHARSET & unicharset,int depth) const44 void RecodeNode::Print(int null_char, const UNICHARSET &unicharset,
45 int depth) const {
46 if (code == null_char) {
47 tprintf("null_char");
48 } else {
49 tprintf("label=%d, uid=%d=%s", code, unichar_id,
50 unicharset.debug_str(unichar_id).c_str());
51 }
52 tprintf(" score=%g, c=%g,%s%s%s perm=%d, hash=%" PRIx64, score, certainty,
53 start_of_dawg ? " DawgStart" : "", start_of_word ? " Start" : "",
54 end_of_word ? " End" : "", permuter, code_hash);
55 if (depth > 0 && prev != nullptr) {
56 tprintf(" prev:");
57 prev->Print(null_char, unicharset, depth - 1);
58 } else {
59 tprintf("\n");
60 }
61 }
62
63 // Borrows the pointer, which is expected to survive until *this is deleted.
RecodeBeamSearch(const UnicharCompress & recoder,int null_char,bool simple_text,Dict * dict)64 RecodeBeamSearch::RecodeBeamSearch(const UnicharCompress &recoder,
65 int null_char, bool simple_text, Dict *dict)
66 : recoder_(recoder),
67 beam_size_(0),
68 top_code_(-1),
69 second_code_(-1),
70 dict_(dict),
71 space_delimited_(true),
72 is_simple_text_(simple_text),
73 null_char_(null_char) {
74 if (dict_ != nullptr && !dict_->IsSpaceDelimitedLang()) {
75 space_delimited_ = false;
76 }
77 }
78
~RecodeBeamSearch()79 RecodeBeamSearch::~RecodeBeamSearch() {
80 for (auto data : beam_) {
81 delete data;
82 }
83 for (auto data : secondary_beam_) {
84 delete data;
85 }
86 }
87
88 // Decodes the set of network outputs, storing the lattice internally.
Decode(const NetworkIO & output,double dict_ratio,double cert_offset,double worst_dict_cert,const UNICHARSET * charset,int lstm_choice_mode)89 void RecodeBeamSearch::Decode(const NetworkIO &output, double dict_ratio,
90 double cert_offset, double worst_dict_cert,
91 const UNICHARSET *charset, int lstm_choice_mode) {
92 beam_size_ = 0;
93 int width = output.Width();
94 if (lstm_choice_mode) {
95 timesteps.clear();
96 }
97 for (int t = 0; t < width; ++t) {
98 ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0]);
99 DecodeStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert,
100 charset);
101 if (lstm_choice_mode) {
102 SaveMostCertainChoices(output.f(t), output.NumFeatures(), charset, t);
103 }
104 }
105 }
Decode(const GENERIC_2D_ARRAY<float> & output,double dict_ratio,double cert_offset,double worst_dict_cert,const UNICHARSET * charset)106 void RecodeBeamSearch::Decode(const GENERIC_2D_ARRAY<float> &output,
107 double dict_ratio, double cert_offset,
108 double worst_dict_cert,
109 const UNICHARSET *charset) {
110 beam_size_ = 0;
111 int width = output.dim1();
112 for (int t = 0; t < width; ++t) {
113 ComputeTopN(output[t], output.dim2(), kBeamWidths[0]);
114 DecodeStep(output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset);
115 }
116 }
117
DecodeSecondaryBeams(const NetworkIO & output,double dict_ratio,double cert_offset,double worst_dict_cert,const UNICHARSET * charset,int lstm_choice_mode)118 void RecodeBeamSearch::DecodeSecondaryBeams(
119 const NetworkIO &output, double dict_ratio, double cert_offset,
120 double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode) {
121 for (auto data : secondary_beam_) {
122 delete data;
123 }
124 secondary_beam_.clear();
125 if (character_boundaries_.size() < 2) {
126 return;
127 }
128 int width = output.Width();
129 unsigned bucketNumber = 0;
130 for (int t = 0; t < width; ++t) {
131 while ((bucketNumber + 1) < character_boundaries_.size() &&
132 t >= character_boundaries_[bucketNumber + 1]) {
133 ++bucketNumber;
134 }
135 ComputeSecTopN(&(excludedUnichars)[bucketNumber], output.f(t),
136 output.NumFeatures(), kBeamWidths[0]);
137 DecodeSecondaryStep(output.f(t), t, dict_ratio, cert_offset,
138 worst_dict_cert, charset);
139 }
140 }
141
SaveMostCertainChoices(const float * outputs,int num_outputs,const UNICHARSET * charset,int xCoord)142 void RecodeBeamSearch::SaveMostCertainChoices(const float *outputs,
143 int num_outputs,
144 const UNICHARSET *charset,
145 int xCoord) {
146 std::vector<std::pair<const char *, float>> choices;
147 for (int i = 0; i < num_outputs; ++i) {
148 if (outputs[i] >= 0.01f) {
149 const char *character;
150 if (i + 2 >= num_outputs) {
151 character = "";
152 } else if (i > 0) {
153 character = charset->id_to_unichar_ext(i + 2);
154 } else {
155 character = charset->id_to_unichar_ext(i);
156 }
157 size_t pos = 0;
158 // order the possible choices within one timestep
159 // beginning with the most likely
160 while (choices.size() > pos && choices[pos].second > outputs[i]) {
161 pos++;
162 }
163 choices.insert(choices.begin() + pos,
164 std::pair<const char *, float>(character, outputs[i]));
165 }
166 }
167 timesteps.push_back(choices);
168 }
169
segmentTimestepsByCharacters()170 void RecodeBeamSearch::segmentTimestepsByCharacters() {
171 for (unsigned i = 1; i < character_boundaries_.size(); ++i) {
172 std::vector<std::vector<std::pair<const char *, float>>> segment;
173 for (int j = character_boundaries_[i - 1]; j < character_boundaries_[i];
174 ++j) {
175 segment.push_back(timesteps[j]);
176 }
177 segmentedTimesteps.push_back(segment);
178 }
179 }
180 std::vector<std::vector<std::pair<const char *, float>>>
combineSegmentedTimesteps(std::vector<std::vector<std::vector<std::pair<const char *,float>>>> * segmentedTimesteps)181 RecodeBeamSearch::combineSegmentedTimesteps(
182 std::vector<std::vector<std::vector<std::pair<const char *, float>>>>
183 *segmentedTimesteps) {
184 std::vector<std::vector<std::pair<const char *, float>>> combined_timesteps;
185 for (auto &segmentedTimestep : *segmentedTimesteps) {
186 for (auto &j : segmentedTimestep) {
187 combined_timesteps.push_back(j);
188 }
189 }
190 return combined_timesteps;
191 }
192
calculateCharBoundaries(std::vector<int> * starts,std::vector<int> * ends,std::vector<int> * char_bounds_,int maxWidth)193 void RecodeBeamSearch::calculateCharBoundaries(std::vector<int> *starts,
194 std::vector<int> *ends,
195 std::vector<int> *char_bounds_,
196 int maxWidth) {
197 char_bounds_->push_back(0);
198 for (unsigned i = 0; i < ends->size(); ++i) {
199 int middle = ((*starts)[i + 1] - (*ends)[i]) / 2;
200 char_bounds_->push_back((*ends)[i] + middle);
201 }
202 char_bounds_->pop_back();
203 char_bounds_->push_back(maxWidth);
204 }
205
206 // Returns the best path as labels/scores/xcoords similar to simple CTC.
ExtractBestPathAsLabels(std::vector<int> * labels,std::vector<int> * xcoords) const207 void RecodeBeamSearch::ExtractBestPathAsLabels(
208 std::vector<int> *labels, std::vector<int> *xcoords) const {
209 labels->clear();
210 xcoords->clear();
211 std::vector<const RecodeNode *> best_nodes;
212 ExtractBestPaths(&best_nodes, nullptr);
213 // Now just run CTC on the best nodes.
214 int t = 0;
215 int width = best_nodes.size();
216 while (t < width) {
217 int label = best_nodes[t]->code;
218 if (label != null_char_) {
219 labels->push_back(label);
220 xcoords->push_back(t);
221 }
222 while (++t < width && !is_simple_text_ && best_nodes[t]->code == label) {
223 }
224 }
225 xcoords->push_back(width);
226 }
227
228 // Returns the best path as unichar-ids/certs/ratings/xcoords skipping
229 // duplicates, nulls and intermediate parts.
ExtractBestPathAsUnicharIds(bool debug,const UNICHARSET * unicharset,std::vector<int> * unichar_ids,std::vector<float> * certs,std::vector<float> * ratings,std::vector<int> * xcoords) const230 void RecodeBeamSearch::ExtractBestPathAsUnicharIds(
231 bool debug, const UNICHARSET *unicharset, std::vector<int> *unichar_ids,
232 std::vector<float> *certs, std::vector<float> *ratings,
233 std::vector<int> *xcoords) const {
234 std::vector<const RecodeNode *> best_nodes;
235 ExtractBestPaths(&best_nodes, nullptr);
236 ExtractPathAsUnicharIds(best_nodes, unichar_ids, certs, ratings, xcoords);
237 if (debug) {
238 DebugPath(unicharset, best_nodes);
239 DebugUnicharPath(unicharset, best_nodes, *unichar_ids, *certs, *ratings,
240 *xcoords);
241 }
242 }
243
244 // Returns the best path as a set of WERD_RES.
ExtractBestPathAsWords(const TBOX & line_box,float scale_factor,bool debug,const UNICHARSET * unicharset,PointerVector<WERD_RES> * words,int lstm_choice_mode)245 void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX &line_box,
246 float scale_factor, bool debug,
247 const UNICHARSET *unicharset,
248 PointerVector<WERD_RES> *words,
249 int lstm_choice_mode) {
250 words->truncate(0);
251 std::vector<int> unichar_ids;
252 std::vector<float> certs;
253 std::vector<float> ratings;
254 std::vector<int> xcoords;
255 std::vector<const RecodeNode *> best_nodes;
256 std::vector<const RecodeNode *> second_nodes;
257 character_boundaries_.clear();
258 ExtractBestPaths(&best_nodes, &second_nodes);
259 if (debug) {
260 DebugPath(unicharset, best_nodes);
261 ExtractPathAsUnicharIds(second_nodes, &unichar_ids, &certs, &ratings,
262 &xcoords);
263 tprintf("\nSecond choice path:\n");
264 DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings,
265 xcoords);
266 }
267 // If lstm choice mode is required in granularity level 2, it stores the x
268 // Coordinates of every chosen character, to match the alternative choices to
269 // it.
270 ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, &xcoords,
271 &character_boundaries_);
272 int num_ids = unichar_ids.size();
273 if (debug) {
274 DebugUnicharPath(unicharset, best_nodes, unichar_ids, certs, ratings,
275 xcoords);
276 }
277 // Convert labels to unichar-ids.
278 int word_end = 0;
279 float prev_space_cert = 0.0f;
280 for (int word_start = 0; word_start < num_ids; word_start = word_end) {
281 for (word_end = word_start + 1; word_end < num_ids; ++word_end) {
282 // A word is terminated when a space character or start_of_word flag is
283 // hit. We also want to force a separate word for every non
284 // space-delimited character when not in a dictionary context.
285 if (unichar_ids[word_end] == UNICHAR_SPACE) {
286 break;
287 }
288 int index = xcoords[word_end];
289 if (best_nodes[index]->start_of_word) {
290 break;
291 }
292 if (best_nodes[index]->permuter == TOP_CHOICE_PERM &&
293 (!unicharset->IsSpaceDelimited(unichar_ids[word_end]) ||
294 !unicharset->IsSpaceDelimited(unichar_ids[word_end - 1]))) {
295 break;
296 }
297 }
298 float space_cert = 0.0f;
299 if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE) {
300 space_cert = certs[word_end];
301 }
302 bool leading_space =
303 word_start > 0 && unichar_ids[word_start - 1] == UNICHAR_SPACE;
304 // Create a WERD_RES for the output word.
305 WERD_RES *word_res =
306 InitializeWord(leading_space, line_box, word_start, word_end,
307 std::min(space_cert, prev_space_cert), unicharset,
308 xcoords, scale_factor);
309 for (int i = word_start; i < word_end; ++i) {
310 auto *choices = new BLOB_CHOICE_LIST;
311 BLOB_CHOICE_IT bc_it(choices);
312 auto *choice = new BLOB_CHOICE(unichar_ids[i], ratings[i], certs[i], -1,
313 1.0f, static_cast<float>(INT16_MAX), 0.0f,
314 BCC_STATIC_CLASSIFIER);
315 int col = i - word_start;
316 choice->set_matrix_cell(col, col);
317 bc_it.add_after_then_move(choice);
318 word_res->ratings->put(col, col, choices);
319 }
320 int index = xcoords[word_end - 1];
321 word_res->FakeWordFromRatings(best_nodes[index]->permuter);
322 words->push_back(word_res);
323 prev_space_cert = space_cert;
324 if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE) {
325 ++word_end;
326 }
327 }
328 }
329
330 struct greater_than {
operator ()tesseract::greater_than331 inline bool operator()(const RecodeNode *&node1, const RecodeNode *&node2) {
332 return (node1->score > node2->score);
333 }
334 };
335
PrintBeam2(bool uids,int num_outputs,const UNICHARSET * charset,bool secondary) const336 void RecodeBeamSearch::PrintBeam2(bool uids, int num_outputs,
337 const UNICHARSET *charset,
338 bool secondary) const {
339 std::vector<std::vector<const RecodeNode *>> topology;
340 std::unordered_set<const RecodeNode *> visited;
341 const std::vector<RecodeBeam *> &beam = !secondary ? beam_ : secondary_beam_;
342 // create the topology
343 for (int step = beam.size() - 1; step >= 0; --step) {
344 std::vector<const RecodeNode *> layer;
345 topology.push_back(layer);
346 }
347 // fill the topology with depths first
348 for (int step = beam.size() - 1; step >= 0; --step) {
349 std::vector<tesseract::RecodePair> &heaps = beam.at(step)->beams_->heap();
350 for (auto node : heaps) {
351 int backtracker = 0;
352 const RecodeNode *curr = &node.data();
353 while (curr != nullptr && !visited.count(curr)) {
354 visited.insert(curr);
355 topology[step - backtracker].push_back(curr);
356 curr = curr->prev;
357 ++backtracker;
358 }
359 }
360 }
361 int ct = 0;
362 unsigned cb = 1;
363 for (const std::vector<const RecodeNode *> &layer : topology) {
364 if (cb >= character_boundaries_.size()) {
365 break;
366 }
367 if (ct == character_boundaries_[cb]) {
368 tprintf("***\n");
369 ++cb;
370 }
371 for (const RecodeNode *node : layer) {
372 const char *code;
373 int intCode;
374 if (node->unichar_id != INVALID_UNICHAR_ID) {
375 code = charset->id_to_unichar(node->unichar_id);
376 intCode = node->unichar_id;
377 } else if (node->code == null_char_) {
378 intCode = 0;
379 code = " ";
380 } else {
381 intCode = 666;
382 code = "*";
383 }
384 int intPrevCode = 0;
385 const char *prevCode;
386 float prevScore = 0;
387 if (node->prev != nullptr) {
388 prevScore = node->prev->score;
389 if (node->prev->unichar_id != INVALID_UNICHAR_ID) {
390 prevCode = charset->id_to_unichar(node->prev->unichar_id);
391 intPrevCode = node->prev->unichar_id;
392 } else if (node->code == null_char_) {
393 intPrevCode = 0;
394 prevCode = " ";
395 } else {
396 prevCode = "*";
397 intPrevCode = 666;
398 }
399 } else {
400 prevCode = " ";
401 }
402 if (uids) {
403 tprintf("%x(|)%f(>)%x(|)%f\n", intPrevCode, prevScore, intCode,
404 node->score);
405 } else {
406 tprintf("%s(|)%f(>)%s(|)%f\n", prevCode, prevScore, code, node->score);
407 }
408 }
409 tprintf("-\n");
410 ++ct;
411 }
412 tprintf("***\n");
413 }
414
extractSymbolChoices(const UNICHARSET * unicharset)415 void RecodeBeamSearch::extractSymbolChoices(const UNICHARSET *unicharset) {
416 if (character_boundaries_.size() < 2) {
417 return;
418 }
419 // For the first iteration the original beam is analyzed. After that a
420 // new beam is calculated based on the results from the original beam.
421 std::vector<RecodeBeam *> ¤tBeam =
422 secondary_beam_.empty() ? beam_ : secondary_beam_;
423 character_boundaries_[0] = 0;
424 for (unsigned j = 1; j < character_boundaries_.size(); ++j) {
425 std::vector<int> unichar_ids;
426 std::vector<float> certs;
427 std::vector<float> ratings;
428 std::vector<int> xcoords;
429 int backpath = character_boundaries_[j] - character_boundaries_[j - 1];
430 std::vector<tesseract::RecodePair> &heaps =
431 currentBeam.at(character_boundaries_[j] - 1)->beams_->heap();
432 std::vector<const RecodeNode *> best_nodes;
433 std::vector<const RecodeNode *> best;
434 // Scan the segmented node chain for valid unichar ids.
435 for (auto entry : heaps) {
436 bool validChar = false;
437 int backcounter = 0;
438 const RecodeNode *node = &entry.data();
439 while (node != nullptr && backcounter < backpath) {
440 if (node->code != null_char_ &&
441 node->unichar_id != INVALID_UNICHAR_ID) {
442 validChar = true;
443 break;
444 }
445 node = node->prev;
446 ++backcounter;
447 }
448 if (validChar) {
449 best.push_back(&entry.data());
450 }
451 }
452 // find the best rated segmented node chain and extract the unichar id.
453 if (!best.empty()) {
454 std::sort(best.begin(), best.end(), greater_than());
455 ExtractPath(best[0], &best_nodes, backpath);
456 ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings,
457 &xcoords);
458 }
459 if (!unichar_ids.empty()) {
460 int bestPos = 0;
461 for (unsigned i = 1; i < unichar_ids.size(); ++i) {
462 if (ratings[i] < ratings[bestPos]) {
463 bestPos = i;
464 }
465 }
466 #if 0 // TODO: bestCode is currently unused (see commit 2dd5d0d60).
467 int bestCode = -10;
468 for (auto &node : best_nodes) {
469 if (node->unichar_id == unichar_ids[bestPos]) {
470 bestCode = node->code;
471 }
472 }
473 #endif
474 // Exclude the best choice for the followup decoding.
475 std::unordered_set<int> excludeCodeList;
476 for (auto &best_node : best_nodes) {
477 if (best_node->code != null_char_) {
478 excludeCodeList.insert(best_node->code);
479 }
480 }
481 if (j - 1 < excludedUnichars.size()) {
482 for (auto elem : excludeCodeList) {
483 excludedUnichars[j - 1].insert(elem);
484 }
485 } else {
486 excludedUnichars.push_back(excludeCodeList);
487 }
488 // Save the best choice for the choice iterator.
489 if (j - 1 < ctc_choices.size()) {
490 int id = unichar_ids[bestPos];
491 const char *result = unicharset->id_to_unichar_ext(id);
492 float rating = ratings[bestPos];
493 ctc_choices[j - 1].push_back(
494 std::pair<const char *, float>(result, rating));
495 } else {
496 std::vector<std::pair<const char *, float>> choice;
497 int id = unichar_ids[bestPos];
498 const char *result = unicharset->id_to_unichar_ext(id);
499 float rating = ratings[bestPos];
500 choice.emplace_back(result, rating);
501 ctc_choices.push_back(choice);
502 }
503 // fill the blank spot with an empty array
504 } else {
505 if (j - 1 >= excludedUnichars.size()) {
506 std::unordered_set<int> excludeCodeList;
507 excludedUnichars.push_back(excludeCodeList);
508 }
509 if (j - 1 >= ctc_choices.size()) {
510 std::vector<std::pair<const char *, float>> choice;
511 ctc_choices.push_back(choice);
512 }
513 }
514 }
515 for (auto data : secondary_beam_) {
516 delete data;
517 }
518 secondary_beam_.clear();
519 }
520
521 // Generates debug output of the content of the beams after a Decode.
DebugBeams(const UNICHARSET & unicharset) const522 void RecodeBeamSearch::DebugBeams(const UNICHARSET &unicharset) const {
523 for (int p = 0; p < beam_size_; ++p) {
524 for (int d = 0; d < 2; ++d) {
525 for (int c = 0; c < NC_COUNT; ++c) {
526 auto cont = static_cast<NodeContinuation>(c);
527 int index = BeamIndex(d, cont, 0);
528 if (beam_[p]->beams_[index].empty()) {
529 continue;
530 }
531 // Print all the best scoring nodes for each unichar found.
532 tprintf("Position %d: %s+%s beam\n", p, d ? "Dict" : "Non-Dict",
533 kNodeContNames[c]);
534 DebugBeamPos(unicharset, beam_[p]->beams_[index]);
535 }
536 }
537 }
538 }
539
540 // Generates debug output of the content of a single beam position.
DebugBeamPos(const UNICHARSET & unicharset,const RecodeHeap & heap) const541 void RecodeBeamSearch::DebugBeamPos(const UNICHARSET &unicharset,
542 const RecodeHeap &heap) const {
543 std::vector<const RecodeNode *> unichar_bests(unicharset.size());
544 const RecodeNode *null_best = nullptr;
545 int heap_size = heap.size();
546 for (int i = 0; i < heap_size; ++i) {
547 const RecodeNode *node = &heap.get(i).data();
548 if (node->unichar_id == INVALID_UNICHAR_ID) {
549 if (null_best == nullptr || null_best->score < node->score) {
550 null_best = node;
551 }
552 } else {
553 if (unichar_bests[node->unichar_id] == nullptr ||
554 unichar_bests[node->unichar_id]->score < node->score) {
555 unichar_bests[node->unichar_id] = node;
556 }
557 }
558 }
559 for (auto &unichar_best : unichar_bests) {
560 if (unichar_best != nullptr) {
561 const RecodeNode &node = *unichar_best;
562 node.Print(null_char_, unicharset, 1);
563 }
564 }
565 if (null_best != nullptr) {
566 null_best->Print(null_char_, unicharset, 1);
567 }
568 }
569
570 // Returns the given best_nodes as unichar-ids/certs/ratings/xcoords skipping
571 // duplicates, nulls and intermediate parts.
572 /* static */
ExtractPathAsUnicharIds(const std::vector<const RecodeNode * > & best_nodes,std::vector<int> * unichar_ids,std::vector<float> * certs,std::vector<float> * ratings,std::vector<int> * xcoords,std::vector<int> * character_boundaries)573 void RecodeBeamSearch::ExtractPathAsUnicharIds(
574 const std::vector<const RecodeNode *> &best_nodes,
575 std::vector<int> *unichar_ids, std::vector<float> *certs,
576 std::vector<float> *ratings, std::vector<int> *xcoords,
577 std::vector<int> *character_boundaries) {
578 unichar_ids->clear();
579 certs->clear();
580 ratings->clear();
581 xcoords->clear();
582 std::vector<int> starts;
583 std::vector<int> ends;
584 // Backtrack extracting only valid, non-duplicate unichar-ids.
585 int t = 0;
586 int width = best_nodes.size();
587 while (t < width) {
588 double certainty = 0.0;
589 double rating = 0.0;
590 while (t < width && best_nodes[t]->unichar_id == INVALID_UNICHAR_ID) {
591 double cert = best_nodes[t++]->certainty;
592 if (cert < certainty) {
593 certainty = cert;
594 }
595 rating -= cert;
596 }
597 starts.push_back(t);
598 if (t < width) {
599 int unichar_id = best_nodes[t]->unichar_id;
600 if (unichar_id == UNICHAR_SPACE && !certs->empty() &&
601 best_nodes[t]->permuter != NO_PERM) {
602 // All the rating and certainty go on the previous character except
603 // for the space itself.
604 if (certainty < certs->back()) {
605 certs->back() = certainty;
606 }
607 ratings->back() += rating;
608 certainty = 0.0;
609 rating = 0.0;
610 }
611 unichar_ids->push_back(unichar_id);
612 xcoords->push_back(t);
613 do {
614 double cert = best_nodes[t++]->certainty;
615 // Special-case NO-PERM space to forget the certainty of the previous
616 // nulls. See long comment in ContinueContext.
617 if (cert < certainty || (unichar_id == UNICHAR_SPACE &&
618 best_nodes[t - 1]->permuter == NO_PERM)) {
619 certainty = cert;
620 }
621 rating -= cert;
622 } while (t < width && best_nodes[t]->duplicate);
623 ends.push_back(t);
624 certs->push_back(certainty);
625 ratings->push_back(rating);
626 } else if (!certs->empty()) {
627 if (certainty < certs->back()) {
628 certs->back() = certainty;
629 }
630 ratings->back() += rating;
631 }
632 }
633 starts.push_back(width);
634 if (character_boundaries != nullptr) {
635 calculateCharBoundaries(&starts, &ends, character_boundaries, width);
636 }
637 xcoords->push_back(width);
638 }
639
640 // Sets up a word with the ratings matrix and fake blobs with boxes in the
641 // right places.
InitializeWord(bool leading_space,const TBOX & line_box,int word_start,int word_end,float space_certainty,const UNICHARSET * unicharset,const std::vector<int> & xcoords,float scale_factor)642 WERD_RES *RecodeBeamSearch::InitializeWord(bool leading_space,
643 const TBOX &line_box, int word_start,
644 int word_end, float space_certainty,
645 const UNICHARSET *unicharset,
646 const std::vector<int> &xcoords,
647 float scale_factor) {
648 // Make a fake blob for each non-zero label.
649 C_BLOB_LIST blobs;
650 C_BLOB_IT b_it(&blobs);
651 for (int i = word_start; i < word_end; ++i) {
652 if (static_cast<unsigned>(i + 1) < character_boundaries_.size()) {
653 TBOX box(static_cast<int16_t>(
654 std::floor(character_boundaries_[i] * scale_factor)) +
655 line_box.left(),
656 line_box.bottom(),
657 static_cast<int16_t>(
658 std::ceil(character_boundaries_[i + 1] * scale_factor)) +
659 line_box.left(),
660 line_box.top());
661 b_it.add_after_then_move(C_BLOB::FakeBlob(box));
662 }
663 }
664 // Make a fake word from the blobs.
665 WERD *word = new WERD(&blobs, leading_space, nullptr);
666 // Make a WERD_RES from the word.
667 auto *word_res = new WERD_RES(word);
668 word_res->end = word_end - word_start + leading_space;
669 word_res->uch_set = unicharset;
670 word_res->combination = true; // Give it ownership of the word.
671 word_res->space_certainty = space_certainty;
672 word_res->ratings = new MATRIX(word_end - word_start, 1);
673 return word_res;
674 }
675
676 // Fills top_n_flags_ with bools that are true iff the corresponding output
677 // is one of the top_n.
ComputeTopN(const float * outputs,int num_outputs,int top_n)678 void RecodeBeamSearch::ComputeTopN(const float *outputs, int num_outputs,
679 int top_n) {
680 top_n_flags_.clear();
681 top_n_flags_.resize(num_outputs, TN_ALSO_RAN);
682 top_code_ = -1;
683 second_code_ = -1;
684 top_heap_.clear();
685 for (int i = 0; i < num_outputs; ++i) {
686 if (top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key()) {
687 TopPair entry(outputs[i], i);
688 top_heap_.Push(&entry);
689 if (top_heap_.size() > top_n) {
690 top_heap_.Pop(&entry);
691 }
692 }
693 }
694 while (!top_heap_.empty()) {
695 TopPair entry;
696 top_heap_.Pop(&entry);
697 if (top_heap_.size() > 1) {
698 top_n_flags_[entry.data()] = TN_TOPN;
699 } else {
700 top_n_flags_[entry.data()] = TN_TOP2;
701 if (top_heap_.empty()) {
702 top_code_ = entry.data();
703 } else {
704 second_code_ = entry.data();
705 }
706 }
707 }
708 top_n_flags_[null_char_] = TN_TOP2;
709 }
710
ComputeSecTopN(std::unordered_set<int> * exList,const float * outputs,int num_outputs,int top_n)711 void RecodeBeamSearch::ComputeSecTopN(std::unordered_set<int> *exList,
712 const float *outputs, int num_outputs,
713 int top_n) {
714 top_n_flags_.clear();
715 top_n_flags_.resize(num_outputs, TN_ALSO_RAN);
716 top_code_ = -1;
717 second_code_ = -1;
718 top_heap_.clear();
719 for (int i = 0; i < num_outputs; ++i) {
720 if ((top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key()) &&
721 !exList->count(i)) {
722 TopPair entry(outputs[i], i);
723 top_heap_.Push(&entry);
724 if (top_heap_.size() > top_n) {
725 top_heap_.Pop(&entry);
726 }
727 }
728 }
729 while (!top_heap_.empty()) {
730 TopPair entry;
731 top_heap_.Pop(&entry);
732 if (top_heap_.size() > 1) {
733 top_n_flags_[entry.data()] = TN_TOPN;
734 } else {
735 top_n_flags_[entry.data()] = TN_TOP2;
736 if (top_heap_.empty()) {
737 top_code_ = entry.data();
738 } else {
739 second_code_ = entry.data();
740 }
741 }
742 }
743 top_n_flags_[null_char_] = TN_TOP2;
744 }
745
746 // Adds the computation for the current time-step to the beam. Call at each
747 // time-step in sequence from left to right. outputs is the activation vector
748 // for the current timestep.
DecodeStep(const float * outputs,int t,double dict_ratio,double cert_offset,double worst_dict_cert,const UNICHARSET * charset,bool debug)749 void RecodeBeamSearch::DecodeStep(const float *outputs, int t,
750 double dict_ratio, double cert_offset,
751 double worst_dict_cert,
752 const UNICHARSET *charset, bool debug) {
753 if (t == static_cast<int>(beam_.size())) {
754 beam_.push_back(new RecodeBeam);
755 }
756 RecodeBeam *step = beam_[t];
757 beam_size_ = t + 1;
758 step->Clear();
759 if (t == 0) {
760 // The first step can only use singles and initials.
761 ContinueContext(nullptr, BeamIndex(false, NC_ANYTHING, 0), outputs, TN_TOP2,
762 charset, dict_ratio, cert_offset, worst_dict_cert, step);
763 if (dict_ != nullptr) {
764 ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs,
765 TN_TOP2, charset, dict_ratio, cert_offset,
766 worst_dict_cert, step);
767 }
768 } else {
769 RecodeBeam *prev = beam_[t - 1];
770 if (debug) {
771 int beam_index = BeamIndex(true, NC_ANYTHING, 0);
772 for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
773 std::vector<const RecodeNode *> path;
774 ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
775 tprintf("Step %d: Dawg beam %d:\n", t, i);
776 DebugPath(charset, path);
777 }
778 beam_index = BeamIndex(false, NC_ANYTHING, 0);
779 for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
780 std::vector<const RecodeNode *> path;
781 ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
782 tprintf("Step %d: Non-Dawg beam %d:\n", t, i);
783 DebugPath(charset, path);
784 }
785 }
786 int total_beam = 0;
787 // Work through the scores by group (top-2, top-n, the rest) while the beam
788 // is empty. This enables extending the context using only the top-n results
789 // first, which may have an empty intersection with the valid codes, so we
790 // fall back to the rest if the beam is empty.
791 for (int tn = 0; tn < TN_COUNT && total_beam == 0; ++tn) {
792 auto top_n = static_cast<TopNState>(tn);
793 for (int index = 0; index < kNumBeams; ++index) {
794 // Working backwards through the heaps doesn't guarantee that we see the
795 // best first, but it comes before a lot of the worst, so it is slightly
796 // more efficient than going forwards.
797 for (int i = prev->beams_[index].size() - 1; i >= 0; --i) {
798 ContinueContext(&prev->beams_[index].get(i).data(), index, outputs,
799 top_n, charset, dict_ratio, cert_offset,
800 worst_dict_cert, step);
801 }
802 }
803 for (int index = 0; index < kNumBeams; ++index) {
804 if (ContinuationFromBeamsIndex(index) == NC_ANYTHING) {
805 total_beam += step->beams_[index].size();
806 }
807 }
808 }
809 // Special case for the best initial dawg. Push it on the heap if good
810 // enough, but there is only one, so it doesn't blow up the beam.
811 for (int c = 0; c < NC_COUNT; ++c) {
812 if (step->best_initial_dawgs_[c].code >= 0) {
813 int index = BeamIndex(true, static_cast<NodeContinuation>(c), 0);
814 RecodeHeap *dawg_heap = &step->beams_[index];
815 PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c],
816 dawg_heap);
817 }
818 }
819 }
820 }
821
DecodeSecondaryStep(const float * outputs,int t,double dict_ratio,double cert_offset,double worst_dict_cert,const UNICHARSET * charset,bool debug)822 void RecodeBeamSearch::DecodeSecondaryStep(
823 const float *outputs, int t, double dict_ratio, double cert_offset,
824 double worst_dict_cert, const UNICHARSET *charset, bool debug) {
825 if (t == static_cast<int>(secondary_beam_.size())) {
826 secondary_beam_.push_back(new RecodeBeam);
827 }
828 RecodeBeam *step = secondary_beam_[t];
829 step->Clear();
830 if (t == 0) {
831 // The first step can only use singles and initials.
832 ContinueContext(nullptr, BeamIndex(false, NC_ANYTHING, 0), outputs, TN_TOP2,
833 charset, dict_ratio, cert_offset, worst_dict_cert, step);
834 if (dict_ != nullptr) {
835 ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs,
836 TN_TOP2, charset, dict_ratio, cert_offset,
837 worst_dict_cert, step);
838 }
839 } else {
840 RecodeBeam *prev = secondary_beam_[t - 1];
841 if (debug) {
842 int beam_index = BeamIndex(true, NC_ANYTHING, 0);
843 for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
844 std::vector<const RecodeNode *> path;
845 ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
846 tprintf("Step %d: Dawg beam %d:\n", t, i);
847 DebugPath(charset, path);
848 }
849 beam_index = BeamIndex(false, NC_ANYTHING, 0);
850 for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
851 std::vector<const RecodeNode *> path;
852 ExtractPath(&prev->beams_[beam_index].get(i).data(), &path);
853 tprintf("Step %d: Non-Dawg beam %d:\n", t, i);
854 DebugPath(charset, path);
855 }
856 }
857 int total_beam = 0;
858 // Work through the scores by group (top-2, top-n, the rest) while the beam
859 // is empty. This enables extending the context using only the top-n results
860 // first, which may have an empty intersection with the valid codes, so we
861 // fall back to the rest if the beam is empty.
862 for (int tn = 0; tn < TN_COUNT && total_beam == 0; ++tn) {
863 auto top_n = static_cast<TopNState>(tn);
864 for (int index = 0; index < kNumBeams; ++index) {
865 // Working backwards through the heaps doesn't guarantee that we see the
866 // best first, but it comes before a lot of the worst, so it is slightly
867 // more efficient than going forwards.
868 for (int i = prev->beams_[index].size() - 1; i >= 0; --i) {
869 ContinueContext(&prev->beams_[index].get(i).data(), index, outputs,
870 top_n, charset, dict_ratio, cert_offset,
871 worst_dict_cert, step);
872 }
873 }
874 for (int index = 0; index < kNumBeams; ++index) {
875 if (ContinuationFromBeamsIndex(index) == NC_ANYTHING) {
876 total_beam += step->beams_[index].size();
877 }
878 }
879 }
880 // Special case for the best initial dawg. Push it on the heap if good
881 // enough, but there is only one, so it doesn't blow up the beam.
882 for (int c = 0; c < NC_COUNT; ++c) {
883 if (step->best_initial_dawgs_[c].code >= 0) {
884 int index = BeamIndex(true, static_cast<NodeContinuation>(c), 0);
885 RecodeHeap *dawg_heap = &step->beams_[index];
886 PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c],
887 dawg_heap);
888 }
889 }
890 }
891 }
892
893 // Adds to the appropriate beams the legal (according to recoder)
894 // continuations of context prev, which is of the given length, using the
895 // given network outputs to provide scores to the choices. Uses only those
896 // choices for which top_n_flags[index] == top_n_flag.
ContinueContext(const RecodeNode * prev,int index,const float * outputs,TopNState top_n_flag,const UNICHARSET * charset,double dict_ratio,double cert_offset,double worst_dict_cert,RecodeBeam * step)897 void RecodeBeamSearch::ContinueContext(
898 const RecodeNode *prev, int index, const float *outputs,
899 TopNState top_n_flag, const UNICHARSET *charset, double dict_ratio,
900 double cert_offset, double worst_dict_cert, RecodeBeam *step) {
901 RecodedCharID prefix;
902 RecodedCharID full_code;
903 const RecodeNode *previous = prev;
904 int length = LengthFromBeamsIndex(index);
905 bool use_dawgs = IsDawgFromBeamsIndex(index);
906 NodeContinuation prev_cont = ContinuationFromBeamsIndex(index);
907 for (int p = length - 1; p >= 0; --p, previous = previous->prev) {
908 while (previous != nullptr &&
909 (previous->duplicate || previous->code == null_char_)) {
910 previous = previous->prev;
911 }
912 if (previous != nullptr) {
913 prefix.Set(p, previous->code);
914 full_code.Set(p, previous->code);
915 }
916 }
917 if (prev != nullptr && !is_simple_text_) {
918 if (top_n_flags_[prev->code] == top_n_flag) {
919 if (prev_cont != NC_NO_DUP) {
920 float cert =
921 NetworkIO::ProbToCertainty(outputs[prev->code]) + cert_offset;
922 PushDupOrNoDawgIfBetter(length, true, prev->code, prev->unichar_id,
923 cert, worst_dict_cert, dict_ratio, use_dawgs,
924 NC_ANYTHING, prev, step);
925 }
926 if (prev_cont == NC_ANYTHING && top_n_flag == TN_TOP2 &&
927 prev->code != null_char_) {
928 float cert = NetworkIO::ProbToCertainty(outputs[prev->code] +
929 outputs[null_char_]) +
930 cert_offset;
931 PushDupOrNoDawgIfBetter(length, true, prev->code, prev->unichar_id,
932 cert, worst_dict_cert, dict_ratio, use_dawgs,
933 NC_NO_DUP, prev, step);
934 }
935 }
936 if (prev_cont == NC_ONLY_DUP) {
937 return;
938 }
939 if (prev->code != null_char_ && length > 0 &&
940 top_n_flags_[null_char_] == top_n_flag) {
941 // Allow nulls within multi code sequences, as the nulls within are not
942 // explicitly included in the code sequence.
943 float cert =
944 NetworkIO::ProbToCertainty(outputs[null_char_]) + cert_offset;
945 PushDupOrNoDawgIfBetter(length, false, null_char_, INVALID_UNICHAR_ID,
946 cert, worst_dict_cert, dict_ratio, use_dawgs,
947 NC_ANYTHING, prev, step);
948 }
949 }
950 const std::vector<int> *final_codes = recoder_.GetFinalCodes(prefix);
951 if (final_codes != nullptr) {
952 for (int code : *final_codes) {
953 if (top_n_flags_[code] != top_n_flag) {
954 continue;
955 }
956 if (prev != nullptr && prev->code == code && !is_simple_text_) {
957 continue;
958 }
959 float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset;
960 if (cert < kMinCertainty && code != null_char_) {
961 continue;
962 }
963 full_code.Set(length, code);
964 int unichar_id = recoder_.DecodeUnichar(full_code);
965 // Map the null char to INVALID.
966 if (length == 0 && code == null_char_) {
967 unichar_id = INVALID_UNICHAR_ID;
968 }
969 if (unichar_id != INVALID_UNICHAR_ID && charset != nullptr &&
970 !charset->get_enabled(unichar_id)) {
971 continue; // disabled by whitelist/blacklist
972 }
973 ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
974 use_dawgs, NC_ANYTHING, prev, step);
975 if (top_n_flag == TN_TOP2 && code != null_char_) {
976 float prob = outputs[code] + outputs[null_char_];
977 if (prev != nullptr && prev_cont == NC_ANYTHING &&
978 prev->code != null_char_ &&
979 ((prev->code == top_code_ && code == second_code_) ||
980 (code == top_code_ && prev->code == second_code_))) {
981 prob += outputs[prev->code];
982 }
983 cert = NetworkIO::ProbToCertainty(prob) + cert_offset;
984 ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
985 use_dawgs, NC_ONLY_DUP, prev, step);
986 }
987 }
988 }
989 const std::vector<int> *next_codes = recoder_.GetNextCodes(prefix);
990 if (next_codes != nullptr) {
991 for (int code : *next_codes) {
992 if (top_n_flags_[code] != top_n_flag) {
993 continue;
994 }
995 if (prev != nullptr && prev->code == code && !is_simple_text_) {
996 continue;
997 }
998 float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset;
999 PushDupOrNoDawgIfBetter(length + 1, false, code, INVALID_UNICHAR_ID, cert,
1000 worst_dict_cert, dict_ratio, use_dawgs,
1001 NC_ANYTHING, prev, step);
1002 if (top_n_flag == TN_TOP2 && code != null_char_) {
1003 float prob = outputs[code] + outputs[null_char_];
1004 if (prev != nullptr && prev_cont == NC_ANYTHING &&
1005 prev->code != null_char_ &&
1006 ((prev->code == top_code_ && code == second_code_) ||
1007 (code == top_code_ && prev->code == second_code_))) {
1008 prob += outputs[prev->code];
1009 }
1010 cert = NetworkIO::ProbToCertainty(prob) + cert_offset;
1011 PushDupOrNoDawgIfBetter(length + 1, false, code, INVALID_UNICHAR_ID,
1012 cert, worst_dict_cert, dict_ratio, use_dawgs,
1013 NC_ONLY_DUP, prev, step);
1014 }
1015 }
1016 }
1017 }
1018
1019 // Continues for a new unichar, using dawg or non-dawg as per flag.
ContinueUnichar(int code,int unichar_id,float cert,float worst_dict_cert,float dict_ratio,bool use_dawgs,NodeContinuation cont,const RecodeNode * prev,RecodeBeam * step)1020 void RecodeBeamSearch::ContinueUnichar(int code, int unichar_id, float cert,
1021 float worst_dict_cert, float dict_ratio,
1022 bool use_dawgs, NodeContinuation cont,
1023 const RecodeNode *prev,
1024 RecodeBeam *step) {
1025 if (use_dawgs) {
1026 if (cert > worst_dict_cert) {
1027 ContinueDawg(code, unichar_id, cert, cont, prev, step);
1028 }
1029 } else {
1030 RecodeHeap *nodawg_heap = &step->beams_[BeamIndex(false, cont, 0)];
1031 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, TOP_CHOICE_PERM, false,
1032 false, false, false, cert * dict_ratio, prev, nullptr,
1033 nodawg_heap);
1034 if (dict_ != nullptr &&
1035 ((unichar_id == UNICHAR_SPACE && cert > worst_dict_cert) ||
1036 !dict_->getUnicharset().IsSpaceDelimited(unichar_id))) {
1037 // Any top choice position that can start a new word, ie a space or
1038 // any non-space-delimited character, should also be considered
1039 // by the dawg search, so push initial dawg to the dawg heap.
1040 float dawg_cert = cert;
1041 PermuterType permuter = TOP_CHOICE_PERM;
1042 // Since we use the space either side of a dictionary word in the
1043 // certainty of the word, (to properly handle weak spaces) and the
1044 // space is coming from a non-dict word, we need special conditions
1045 // to avoid degrading the certainty of the dict word that follows.
1046 // With a space we don't multiply the certainty by dict_ratio, and we
1047 // flag the space with NO_PERM to indicate that we should not use the
1048 // predecessor nulls to generate the confidence for the space, as they
1049 // have already been multiplied by dict_ratio, and we can't go back to
1050 // insert more entries in any previous heaps.
1051 if (unichar_id == UNICHAR_SPACE) {
1052 permuter = NO_PERM;
1053 } else {
1054 dawg_cert *= dict_ratio;
1055 }
1056 PushInitialDawgIfBetter(code, unichar_id, permuter, false, false,
1057 dawg_cert, cont, prev, step);
1058 }
1059 }
1060 }
1061
1062 // Adds a RecodeNode composed of the tuple (code, unichar_id, cert, prev,
1063 // appropriate-dawg-args, cert) to the given heap (dawg_beam_) if unichar_id
1064 // is a valid continuation of whatever is in prev.
ContinueDawg(int code,int unichar_id,float cert,NodeContinuation cont,const RecodeNode * prev,RecodeBeam * step)1065 void RecodeBeamSearch::ContinueDawg(int code, int unichar_id, float cert,
1066 NodeContinuation cont,
1067 const RecodeNode *prev, RecodeBeam *step) {
1068 RecodeHeap *dawg_heap = &step->beams_[BeamIndex(true, cont, 0)];
1069 RecodeHeap *nodawg_heap = &step->beams_[BeamIndex(false, cont, 0)];
1070 if (unichar_id == INVALID_UNICHAR_ID) {
1071 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, NO_PERM, false, false,
1072 false, false, cert, prev, nullptr, dawg_heap);
1073 return;
1074 }
1075 // Avoid dictionary probe if score a total loss.
1076 float score = cert;
1077 if (prev != nullptr) {
1078 score += prev->score;
1079 }
1080 if (dawg_heap->size() >= kBeamWidths[0] &&
1081 score <= dawg_heap->PeekTop().data().score &&
1082 nodawg_heap->size() >= kBeamWidths[0] &&
1083 score <= nodawg_heap->PeekTop().data().score) {
1084 return;
1085 }
1086 const RecodeNode *uni_prev = prev;
1087 // Prev may be a partial code, null_char, or duplicate, so scan back to the
1088 // last valid unichar_id.
1089 while (uni_prev != nullptr &&
1090 (uni_prev->unichar_id == INVALID_UNICHAR_ID || uni_prev->duplicate)) {
1091 uni_prev = uni_prev->prev;
1092 }
1093 if (unichar_id == UNICHAR_SPACE) {
1094 if (uni_prev != nullptr && uni_prev->end_of_word) {
1095 // Space is good. Push initial state, to the dawg beam and a regular
1096 // space to the top choice beam.
1097 PushInitialDawgIfBetter(code, unichar_id, uni_prev->permuter, false,
1098 false, cert, cont, prev, step);
1099 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, uni_prev->permuter,
1100 false, false, false, false, cert, prev, nullptr,
1101 nodawg_heap);
1102 }
1103 return;
1104 } else if (uni_prev != nullptr && uni_prev->start_of_dawg &&
1105 uni_prev->unichar_id != UNICHAR_SPACE &&
1106 dict_->getUnicharset().IsSpaceDelimited(uni_prev->unichar_id) &&
1107 dict_->getUnicharset().IsSpaceDelimited(unichar_id)) {
1108 return; // Can't break words between space delimited chars.
1109 }
1110 DawgPositionVector initial_dawgs;
1111 auto *updated_dawgs = new DawgPositionVector;
1112 DawgArgs dawg_args(&initial_dawgs, updated_dawgs, NO_PERM);
1113 bool word_start = false;
1114 if (uni_prev == nullptr) {
1115 // Starting from beginning of line.
1116 dict_->default_dawgs(&initial_dawgs, false);
1117 word_start = true;
1118 } else if (uni_prev->dawgs != nullptr) {
1119 // Continuing a previous dict word.
1120 dawg_args.active_dawgs = uni_prev->dawgs;
1121 word_start = uni_prev->start_of_dawg;
1122 } else {
1123 return; // Can't continue if not a dict word.
1124 }
1125 auto permuter = static_cast<PermuterType>(dict_->def_letter_is_okay(
1126 &dawg_args, dict_->getUnicharset(), unichar_id, false));
1127 if (permuter != NO_PERM) {
1128 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter, false,
1129 word_start, dawg_args.valid_end, false, cert, prev,
1130 dawg_args.updated_dawgs, dawg_heap);
1131 if (dawg_args.valid_end && !space_delimited_) {
1132 // We can start another word right away, so push initial state as well,
1133 // to the dawg beam, and the regular character to the top choice beam,
1134 // since non-dict words can start here too.
1135 PushInitialDawgIfBetter(code, unichar_id, permuter, word_start, true,
1136 cert, cont, prev, step);
1137 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter, false,
1138 word_start, true, false, cert, prev, nullptr,
1139 nodawg_heap);
1140 }
1141 } else {
1142 delete updated_dawgs;
1143 }
1144 }
1145
1146 // Adds a RecodeNode composed of the tuple (code, unichar_id,
1147 // initial-dawg-state, prev, cert) to the given heap if/ there is room or if
1148 // better than the current worst element if already full.
PushInitialDawgIfBetter(int code,int unichar_id,PermuterType permuter,bool start,bool end,float cert,NodeContinuation cont,const RecodeNode * prev,RecodeBeam * step)1149 void RecodeBeamSearch::PushInitialDawgIfBetter(int code, int unichar_id,
1150 PermuterType permuter,
1151 bool start, bool end, float cert,
1152 NodeContinuation cont,
1153 const RecodeNode *prev,
1154 RecodeBeam *step) {
1155 RecodeNode *best_initial_dawg = &step->best_initial_dawgs_[cont];
1156 float score = cert;
1157 if (prev != nullptr) {
1158 score += prev->score;
1159 }
1160 if (best_initial_dawg->code < 0 || score > best_initial_dawg->score) {
1161 auto *initial_dawgs = new DawgPositionVector;
1162 dict_->default_dawgs(initial_dawgs, false);
1163 RecodeNode node(code, unichar_id, permuter, true, start, end, false, cert,
1164 score, prev, initial_dawgs,
1165 ComputeCodeHash(code, false, prev));
1166 *best_initial_dawg = node;
1167 }
1168 }
1169
1170 // Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
1171 // false, false, false, false, cert, prev, nullptr) to heap if there is room
1172 // or if better than the current worst element if already full.
1173 /* static */
PushDupOrNoDawgIfBetter(int length,bool dup,int code,int unichar_id,float cert,float worst_dict_cert,float dict_ratio,bool use_dawgs,NodeContinuation cont,const RecodeNode * prev,RecodeBeam * step)1174 void RecodeBeamSearch::PushDupOrNoDawgIfBetter(
1175 int length, bool dup, int code, int unichar_id, float cert,
1176 float worst_dict_cert, float dict_ratio, bool use_dawgs,
1177 NodeContinuation cont, const RecodeNode *prev, RecodeBeam *step) {
1178 int index = BeamIndex(use_dawgs, cont, length);
1179 if (use_dawgs) {
1180 if (cert > worst_dict_cert) {
1181 PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
1182 prev ? prev->permuter : NO_PERM, false, false, false,
1183 dup, cert, prev, nullptr, &step->beams_[index]);
1184 }
1185 } else {
1186 cert *= dict_ratio;
1187 if (cert >= kMinCertainty || code == null_char_) {
1188 PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
1189 prev ? prev->permuter : TOP_CHOICE_PERM, false, false,
1190 false, dup, cert, prev, nullptr, &step->beams_[index]);
1191 }
1192 }
1193 }
1194
1195 // Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
1196 // dawg_start, word_start, end, dup, cert, prev, d) to heap if there is room
1197 // or if better than the current worst element if already full.
PushHeapIfBetter(int max_size,int code,int unichar_id,PermuterType permuter,bool dawg_start,bool word_start,bool end,bool dup,float cert,const RecodeNode * prev,DawgPositionVector * d,RecodeHeap * heap)1198 void RecodeBeamSearch::PushHeapIfBetter(int max_size, int code, int unichar_id,
1199 PermuterType permuter, bool dawg_start,
1200 bool word_start, bool end, bool dup,
1201 float cert, const RecodeNode *prev,
1202 DawgPositionVector *d,
1203 RecodeHeap *heap) {
1204 float score = cert;
1205 if (prev != nullptr) {
1206 score += prev->score;
1207 }
1208 if (heap->size() < max_size || score > heap->PeekTop().data().score) {
1209 uint64_t hash = ComputeCodeHash(code, dup, prev);
1210 RecodeNode node(code, unichar_id, permuter, dawg_start, word_start, end,
1211 dup, cert, score, prev, d, hash);
1212 if (UpdateHeapIfMatched(&node, heap)) {
1213 return;
1214 }
1215 RecodePair entry(score, node);
1216 heap->Push(&entry);
1217 ASSERT_HOST(entry.data().dawgs == nullptr);
1218 if (heap->size() > max_size) {
1219 heap->Pop(&entry);
1220 }
1221 } else {
1222 delete d;
1223 }
1224 }
1225
1226 // Adds a RecodeNode to heap if there is room
1227 // or if better than the current worst element if already full.
PushHeapIfBetter(int max_size,RecodeNode * node,RecodeHeap * heap)1228 void RecodeBeamSearch::PushHeapIfBetter(int max_size, RecodeNode *node,
1229 RecodeHeap *heap) {
1230 if (heap->size() < max_size || node->score > heap->PeekTop().data().score) {
1231 if (UpdateHeapIfMatched(node, heap)) {
1232 return;
1233 }
1234 RecodePair entry(node->score, *node);
1235 heap->Push(&entry);
1236 ASSERT_HOST(entry.data().dawgs == nullptr);
1237 if (heap->size() > max_size) {
1238 heap->Pop(&entry);
1239 }
1240 }
1241 }
1242
1243 // Searches the heap for a matching entry, and updates the score with
1244 // reshuffle if needed. Returns true if there was a match.
UpdateHeapIfMatched(RecodeNode * new_node,RecodeHeap * heap)1245 bool RecodeBeamSearch::UpdateHeapIfMatched(RecodeNode *new_node,
1246 RecodeHeap *heap) {
1247 // TODO(rays) consider hash map instead of linear search.
1248 // It might not be faster because the hash map would have to be updated
1249 // every time a heap reshuffle happens, and that would be a lot of overhead.
1250 std::vector<RecodePair> &nodes = heap->heap();
1251 for (auto &i : nodes) {
1252 RecodeNode &node = i.data();
1253 if (node.code == new_node->code && node.code_hash == new_node->code_hash &&
1254 node.permuter == new_node->permuter &&
1255 node.start_of_dawg == new_node->start_of_dawg) {
1256 if (new_node->score > node.score) {
1257 // The new one is better. Update the entire node in the heap and
1258 // reshuffle.
1259 node = *new_node;
1260 i.key() = node.score;
1261 heap->Reshuffle(&i);
1262 }
1263 return true;
1264 }
1265 }
1266 return false;
1267 }
1268
1269 // Computes and returns the code-hash for the given code and prev.
ComputeCodeHash(int code,bool dup,const RecodeNode * prev) const1270 uint64_t RecodeBeamSearch::ComputeCodeHash(int code, bool dup,
1271 const RecodeNode *prev) const {
1272 uint64_t hash = prev == nullptr ? 0 : prev->code_hash;
1273 if (!dup && code != null_char_) {
1274 int num_classes = recoder_.code_range();
1275 uint64_t carry = (((hash >> 32) * num_classes) >> 32);
1276 hash *= num_classes;
1277 hash += carry;
1278 hash += code;
1279 }
1280 return hash;
1281 }
1282
1283 // Backtracks to extract the best path through the lattice that was built
1284 // during Decode. On return the best_nodes vector essentially contains the set
1285 // of code, score pairs that make the optimal path with the constraint that
1286 // the recoder can decode the code sequence back to a sequence of unichar-ids.
ExtractBestPaths(std::vector<const RecodeNode * > * best_nodes,std::vector<const RecodeNode * > * second_nodes) const1287 void RecodeBeamSearch::ExtractBestPaths(
1288 std::vector<const RecodeNode *> *best_nodes,
1289 std::vector<const RecodeNode *> *second_nodes) const {
1290 // Scan both beams to extract the best and second best paths.
1291 const RecodeNode *best_node = nullptr;
1292 const RecodeNode *second_best_node = nullptr;
1293 const RecodeBeam *last_beam = beam_[beam_size_ - 1];
1294 for (int c = 0; c < NC_COUNT; ++c) {
1295 if (c == NC_ONLY_DUP) {
1296 continue;
1297 }
1298 auto cont = static_cast<NodeContinuation>(c);
1299 for (int is_dawg = 0; is_dawg < 2; ++is_dawg) {
1300 int beam_index = BeamIndex(is_dawg, cont, 0);
1301 int heap_size = last_beam->beams_[beam_index].size();
1302 for (int h = 0; h < heap_size; ++h) {
1303 const RecodeNode *node = &last_beam->beams_[beam_index].get(h).data();
1304 if (is_dawg) {
1305 // dawg_node may be a null_char, or duplicate, so scan back to the
1306 // last valid unichar_id.
1307 const RecodeNode *dawg_node = node;
1308 while (dawg_node != nullptr &&
1309 (dawg_node->unichar_id == INVALID_UNICHAR_ID ||
1310 dawg_node->duplicate)) {
1311 dawg_node = dawg_node->prev;
1312 }
1313 if (dawg_node == nullptr ||
1314 (!dawg_node->end_of_word &&
1315 dawg_node->unichar_id != UNICHAR_SPACE)) {
1316 // Dawg node is not valid.
1317 continue;
1318 }
1319 }
1320 if (best_node == nullptr || node->score > best_node->score) {
1321 second_best_node = best_node;
1322 best_node = node;
1323 } else if (second_best_node == nullptr ||
1324 node->score > second_best_node->score) {
1325 second_best_node = node;
1326 }
1327 }
1328 }
1329 }
1330 if (second_nodes != nullptr) {
1331 ExtractPath(second_best_node, second_nodes);
1332 }
1333 ExtractPath(best_node, best_nodes);
1334 }
1335
1336 // Helper backtracks through the lattice from the given node, storing the
1337 // path and reversing it.
ExtractPath(const RecodeNode * node,std::vector<const RecodeNode * > * path) const1338 void RecodeBeamSearch::ExtractPath(
1339 const RecodeNode *node, std::vector<const RecodeNode *> *path) const {
1340 path->clear();
1341 while (node != nullptr) {
1342 path->push_back(node);
1343 node = node->prev;
1344 }
1345 std::reverse(path->begin(), path->end());
1346 }
1347
ExtractPath(const RecodeNode * node,std::vector<const RecodeNode * > * path,int limiter) const1348 void RecodeBeamSearch::ExtractPath(const RecodeNode *node,
1349 std::vector<const RecodeNode *> *path,
1350 int limiter) const {
1351 int pathcounter = 0;
1352 path->clear();
1353 while (node != nullptr && pathcounter < limiter) {
1354 path->push_back(node);
1355 node = node->prev;
1356 ++pathcounter;
1357 }
1358 std::reverse(path->begin(), path->end());
1359 }
1360
1361 // Helper prints debug information on the given lattice path.
DebugPath(const UNICHARSET * unicharset,const std::vector<const RecodeNode * > & path) const1362 void RecodeBeamSearch::DebugPath(
1363 const UNICHARSET *unicharset,
1364 const std::vector<const RecodeNode *> &path) const {
1365 for (unsigned c = 0; c < path.size(); ++c) {
1366 const RecodeNode &node = *path[c];
1367 tprintf("%u ", c);
1368 node.Print(null_char_, *unicharset, 1);
1369 }
1370 }
1371
1372 // Helper prints debug information on the given unichar path.
DebugUnicharPath(const UNICHARSET * unicharset,const std::vector<const RecodeNode * > & path,const std::vector<int> & unichar_ids,const std::vector<float> & certs,const std::vector<float> & ratings,const std::vector<int> & xcoords) const1373 void RecodeBeamSearch::DebugUnicharPath(
1374 const UNICHARSET *unicharset, const std::vector<const RecodeNode *> &path,
1375 const std::vector<int> &unichar_ids, const std::vector<float> &certs,
1376 const std::vector<float> &ratings, const std::vector<int> &xcoords) const {
1377 auto num_ids = unichar_ids.size();
1378 double total_rating = 0.0;
1379 for (unsigned c = 0; c < num_ids; ++c) {
1380 int coord = xcoords[c];
1381 tprintf("%d %d=%s r=%g, c=%g, s=%d, e=%d, perm=%d\n", coord, unichar_ids[c],
1382 unicharset->debug_str(unichar_ids[c]).c_str(), ratings[c], certs[c],
1383 path[coord]->start_of_word, path[coord]->end_of_word,
1384 path[coord]->permuter);
1385 total_rating += ratings[c];
1386 }
1387 tprintf("Path total rating = %g\n", total_rating);
1388 }
1389
1390 } // namespace tesseract.
1391