1 /////////////////////////////////////////////////////////////////////// 2 // File: recodebeam.h 3 // Description: Beam search to decode from the re-encoded CJK as a sequence of 4 // smaller numbers in place of a single large code. 5 // Author: Ray Smith 6 // 7 // (C) Copyright 2015, Google Inc. 8 // Licensed under the Apache License, Version 2.0 (the "License"); 9 // you may not use this file except in compliance with the License. 10 // You may obtain a copy of the License at 11 // http://www.apache.org/licenses/LICENSE-2.0 12 // Unless required by applicable law or agreed to in writing, software 13 // distributed under the License is distributed on an "AS IS" BASIS, 14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 // See the License for the specific language governing permissions and 16 // limitations under the License. 17 // 18 /////////////////////////////////////////////////////////////////////// 19 20 #ifndef THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_ 21 #define THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_ 22 23 #include "dawg.h" 24 #include "dict.h" 25 #include "genericheap.h" 26 #include "genericvector.h" 27 #include "kdpair.h" 28 #include "networkio.h" 29 #include "ratngs.h" 30 #include "unicharcompress.h" 31 32 #include <deque> 33 #include <set> 34 #include <tuple> 35 #include <unordered_set> 36 #include <vector> 37 38 namespace tesseract { 39 40 // Enum describing what can follow the current node. 
// Consider the following softmax outputs:
// Timestep     0     1     2     3     4     5     6     7     8
// X-score     0.01  0.55  0.98  0.42  0.01  0.01  0.40  0.95  0.01
// Y-score     0.00  0.01  0.01  0.01  0.01  0.97  0.59  0.04  0.01
// Null-score  0.99  0.44  0.01  0.57  0.98  0.02  0.01  0.01  0.98
// Then the correct CTC decoding (in which adjacent equal classes are folded,
// and then all nulls are dropped) is clearly XYX, but simple decoding (taking
// the max at each timestep) leads to:
// Null@0.99 X@0.55 X@0.98 Null@0.57 Null@0.98 Y@0.97 Y@0.59 X@0.95 Null@0.98,
// which folds to the correct XYX. The conversion to Tesseract rating and
// certainty uses the sum of the log probs (log of the product of probabilities)
// for the rating and the minimum log prob for the certainty, but that yields a
// minimum certainty of log(0.55), which is poor for such an obvious case.
// CTC says that the probability of the result is the SUM of the products of the
// probabilities over ALL PATHS that decode to the same result, which includes:
// NXXNNYYXN, NNXNNYYXN, NXXXNYYXN, NNXXNYXXN, and others including XXXXXYYXX.
// That is intractable, so some compromise between simple and ideal is needed.
// Observing that evenly split timesteps rarely happen next to each other, we
// allow scores at a transition between classes to be added for decoding thus:
// N@0.99 (N+X)@0.99 X@0.98 (N+X)@0.99 N@0.98 Y@0.97 (X+Y+N)@1.00 X@0.95 N@0.98.
// This works because NNX and NXX both decode to X, so in the middle we can use
// N+X. Note that the classes on either side of a sum must stand alone, i.e. use
// a single score, to force all paths to pass through them and decode to the
// same result. Also, in the special case of a transition from X to Y with only
// one timestep between them, it is possible to add X+Y+N, since XXY, XYY, and
// XNY all decode to XY.
// An important condition is that X and Null must NOT be combined between two
// stand-alone Xs: that timestep could decode as XNX->XX or as XXX->X, so the
// X score and the Null score have to live on separate paths. Combining scores
// this way gives a much better minimum certainty of log(0.95).
// The beam search therefore has to place the possibilities X, X+N and X+Y+N
// into the beam only under the right conditions on the previous node, and it
// must constrain what may follow, so that the rules above are enforced.
// This gives 3 kinds of node, classified by what is allowed to follow them:
enum NodeContinuation {
  // The node used only its own score, so any continuation is legal.
  NC_ANYTHING = 0,
  // The node merged another class's score into its own without a preceding
  // stand-alone duplicate, so the next node must be a stand-alone duplicate.
  NC_ONLY_DUP = 1,
  // The node merged another class's score into its own after a stand-alone
  // duplicate, so the next node must be something other than a duplicate.
  NC_NO_DUP = 2,
  // Number of continuation kinds (not a real kind).
  NC_COUNT = 3
};

// Top-n classification of a code at the current timestep.
enum TopNState {
  // The best or second-best code.
  TN_TOP2 = 0,
  // Within the top-n, but neither 1st nor 2nd.
  TN_TOPN = 1,
  // Outside the top-n.
  TN_ALSO_RAN = 2,
  // Number of states (not a real state).
  TN_COUNT = 3
};

// Lattice element for Re-encode beam search.
95 struct RecodeNode { RecodeNodeRecodeNode96 RecodeNode() 97 : code(-1) 98 , unichar_id(INVALID_UNICHAR_ID) 99 , permuter(TOP_CHOICE_PERM) 100 , start_of_dawg(false) 101 , start_of_word(false) 102 , end_of_word(false) 103 , duplicate(false) 104 , certainty(0.0f) 105 , score(0.0f) 106 , prev(nullptr) 107 , dawgs(nullptr) 108 , code_hash(0) {} RecodeNodeRecodeNode109 RecodeNode(int c, int uni_id, PermuterType perm, bool dawg_start, bool word_start, bool end, 110 bool dup, float cert, float s, const RecodeNode *p, DawgPositionVector *d, 111 uint64_t hash) 112 : code(c) 113 , unichar_id(uni_id) 114 , permuter(perm) 115 , start_of_dawg(dawg_start) 116 , start_of_word(word_start) 117 , end_of_word(end) 118 , duplicate(dup) 119 , certainty(cert) 120 , score(s) 121 , prev(p) 122 , dawgs(d) 123 , code_hash(hash) {} 124 // NOTE: If we could use C++11, then this would be a move constructor. 125 // Instead we have copy constructor that does a move!! This is because we 126 // don't want to copy the whole DawgPositionVector each time, and true 127 // copying isn't necessary for this struct. It does get moved around a lot 128 // though inside the heap and during heap push, hence the move semantics. RecodeNodeRecodeNode129 RecodeNode(const RecodeNode &src) : dawgs(nullptr) { 130 *this = src; 131 ASSERT_HOST(src.dawgs == nullptr); 132 } 133 RecodeNode &operator=(const RecodeNode &src) { 134 delete dawgs; 135 memcpy(this, &src, sizeof(src)); 136 ((RecodeNode &)src).dawgs = nullptr; 137 return *this; 138 } ~RecodeNodeRecodeNode139 ~RecodeNode() { 140 delete dawgs; 141 } 142 // Prints details of the node. 143 void Print(int null_char, const UNICHARSET &unicharset, int depth) const; 144 145 // The re-encoded code here = index to network output. 146 int code; 147 // The decoded unichar_id is only valid for the final code of a sequence. 148 int unichar_id; 149 // The type of permuter active at this point. 
Intervals between start_of_word 150 // and end_of_word make valid words of type given by permuter where 151 // end_of_word is true. These aren't necessarily delimited by spaces. 152 PermuterType permuter; 153 // True if this is the initial dawg state. May be attached to a space or, 154 // in a non-space-delimited lang, the end of the previous word. 155 bool start_of_dawg; 156 // True if this is the first node in a dictionary word. 157 bool start_of_word; 158 // True if this represents a valid candidate end of word position. Does not 159 // necessarily mark the end of a word, since a word can be extended beyond a 160 // candidate end by a continuation, eg 'the' continues to 'these'. 161 bool end_of_word; 162 // True if this->code is a duplicate of prev->code. Some training modes 163 // allow the network to output duplicate characters and crush them with CTC, 164 // but that would mess up the dictionary search, so we just smash them 165 // together on the fly using the duplicate flag. 166 bool duplicate; 167 // Certainty (log prob) of (just) this position. 168 float certainty; 169 // Total certainty of the path to this position. 170 float score; 171 // The previous node in this chain. Borrowed pointer. 172 const RecodeNode *prev; 173 // The currently active dawgs at this position. Owned pointer. 174 DawgPositionVector *dawgs; 175 // A hash of all codes in the prefix and this->code as well. Used for 176 // duplicate path removal. 177 uint64_t code_hash; 178 }; 179 180 using RecodePair = KDPairInc<double, RecodeNode>; 181 using RecodeHeap = GenericHeap<RecodePair>; 182 183 // Class that holds the entire beam search for recognition of a text line. 184 class TESS_API RecodeBeamSearch { 185 public: 186 // Borrows the pointer, which is expected to survive until *this is deleted. 
187 RecodeBeamSearch(const UnicharCompress &recoder, int null_char, bool simple_text, Dict *dict); 188 ~RecodeBeamSearch(); 189 190 // Decodes the set of network outputs, storing the lattice internally. 191 // If charset is not null, it enables detailed debugging of the beam search. 192 void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, 193 double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode = 0); 194 void Decode(const GENERIC_2D_ARRAY<float> &output, double dict_ratio, double cert_offset, 195 double worst_dict_cert, const UNICHARSET *charset); 196 197 void DecodeSecondaryBeams(const NetworkIO &output, double dict_ratio, double cert_offset, 198 double worst_dict_cert, const UNICHARSET *charset, 199 int lstm_choice_mode = 0); 200 201 // Returns the best path as labels/scores/xcoords similar to simple CTC. 202 void ExtractBestPathAsLabels(std::vector<int> *labels, std::vector<int> *xcoords) const; 203 // Returns the best path as unichar-ids/certs/ratings/xcoords skipping 204 // duplicates, nulls and intermediate parts. 205 void ExtractBestPathAsUnicharIds(bool debug, const UNICHARSET *unicharset, 206 std::vector<int> *unichar_ids, std::vector<float> *certs, 207 std::vector<float> *ratings, std::vector<int> *xcoords) const; 208 209 // Returns the best path as a set of WERD_RES. 210 void ExtractBestPathAsWords(const TBOX &line_box, float scale_factor, bool debug, 211 const UNICHARSET *unicharset, PointerVector<WERD_RES> *words, 212 int lstm_choice_mode = 0); 213 214 // Generates debug output of the content of the beams after a Decode. 215 void DebugBeams(const UNICHARSET &unicharset) const; 216 217 // Extract the best charakters from the current decode iteration and block 218 // those symbols for the next iteration. 
In contrast to tesseracts standard 219 // method to chose the best overall node chain, this methods looks at a short 220 // node chain segmented by the character boundaries and chooses the best 221 // option independent of the remaining node chain. 222 void extractSymbolChoices(const UNICHARSET *unicharset); 223 224 // Generates debug output of the content of the beams after a Decode. 225 void PrintBeam2(bool uids, int num_outputs, const UNICHARSET *charset, bool secondary) const; 226 // Segments the timestep bundle by the character_boundaries. 227 void segmentTimestepsByCharacters(); 228 std::vector<std::vector<std::pair<const char *, float>>> 229 // Unions the segmented timestep character bundles to one big bundle. 230 combineSegmentedTimesteps( 231 std::vector<std::vector<std::vector<std::pair<const char *, float>>>> *segmentedTimesteps); 232 // Stores the alternative characters of every timestep together with their 233 // probability. 234 std::vector<std::vector<std::pair<const char *, float>>> timesteps; 235 std::vector<std::vector<std::vector<std::pair<const char *, float>>>> segmentedTimesteps; 236 // Stores the character choices found in the ctc algorithm 237 std::vector<std::vector<std::pair<const char *, float>>> ctc_choices; 238 // Stores all unicharids which are excluded for future iterations 239 std::vector<std::unordered_set<int>> excludedUnichars; 240 // Stores the character boundaries regarding timesteps. 241 std::vector<int> character_boundaries_; 242 // Clipping value for certainty inside Tesseract. Reflects the minimum value 243 // of certainty that will be returned by ExtractBestPathAsUnicharIds. 244 // Supposedly on a uniform scale that can be compared across languages and 245 // engines. 246 static constexpr float kMinCertainty = -20.0f; 247 // Number of different code lengths for which we have a separate beam. 
248 static const int kNumLengths = RecodedCharID::kMaxCodeLen + 1; 249 // Total number of beams: dawg/nodawg * number of NodeContinuation * number 250 // of different lengths. 251 static const int kNumBeams = 2 * NC_COUNT * kNumLengths; 252 // Returns the relevant factor in the beams_ index. LengthFromBeamsIndex(int index)253 static int LengthFromBeamsIndex(int index) { 254 return index % kNumLengths; 255 } ContinuationFromBeamsIndex(int index)256 static NodeContinuation ContinuationFromBeamsIndex(int index) { 257 return static_cast<NodeContinuation>((index / kNumLengths) % NC_COUNT); 258 } IsDawgFromBeamsIndex(int index)259 static bool IsDawgFromBeamsIndex(int index) { 260 return index / (kNumLengths * NC_COUNT) > 0; 261 } 262 // Computes a beams_ index from the given factors. BeamIndex(bool is_dawg,NodeContinuation cont,int length)263 static int BeamIndex(bool is_dawg, NodeContinuation cont, int length) { 264 return (is_dawg * NC_COUNT + cont) * kNumLengths + length; 265 } 266 267 private: 268 // Struct for the Re-encode beam search. This struct holds the data for 269 // a single time-step position of the output. Use a vector<RecodeBeam> 270 // to hold all the timesteps and prevent reallocation of the individual heaps. 271 struct RecodeBeam { 272 // Resets to the initial state without deleting all the memory. ClearRecodeBeam273 void Clear() { 274 for (auto &beam : beams_) { 275 beam.clear(); 276 } 277 RecodeNode empty; 278 for (auto &best_initial_dawg : best_initial_dawgs_) { 279 best_initial_dawg = empty; 280 } 281 } 282 283 // A separate beam for each combination of code length, 284 // NodeContinuation, and dictionary flag. Separating out all these types 285 // allows the beam to be quite narrow, and yet still have a low chance of 286 // losing the best path. 287 // We have to keep all these beams separate, since the highest scoring paths 288 // come from the paths that are most likely to dead-end at any time, like 289 // dawg paths, NC_ONLY_DUP etc. 
290 // Each heap is stored with the WORST result at the top, so we can quickly 291 // get the top-n values. 292 RecodeHeap beams_[kNumBeams]; 293 // While the language model is only a single word dictionary, we can use 294 // word starts as a choke point in the beam, and keep only a single dict 295 // start node at each step (for each NodeContinuation type), so we find the 296 // best one here and push it on the heap, if it qualifies, after processing 297 // all of the step. 298 RecodeNode best_initial_dawgs_[NC_COUNT]; 299 }; 300 using TopPair = KDPairInc<float, int>; 301 302 // Generates debug output of the content of a single beam position. 303 void DebugBeamPos(const UNICHARSET &unicharset, const RecodeHeap &heap) const; 304 305 // Returns the given best_nodes as unichar-ids/certs/ratings/xcoords skipping 306 // duplicates, nulls and intermediate parts. 307 static void ExtractPathAsUnicharIds(const std::vector<const RecodeNode *> &best_nodes, 308 std::vector<int> *unichar_ids, std::vector<float> *certs, 309 std::vector<float> *ratings, std::vector<int> *xcoords, 310 std::vector<int> *character_boundaries = nullptr); 311 312 // Sets up a word with the ratings matrix and fake blobs with boxes in the 313 // right places. 314 WERD_RES *InitializeWord(bool leading_space, const TBOX &line_box, int word_start, int word_end, 315 float space_certainty, const UNICHARSET *unicharset, 316 const std::vector<int> &xcoords, float scale_factor); 317 318 // Fills top_n_flags_ with bools that are true iff the corresponding output 319 // is one of the top_n. 320 void ComputeTopN(const float *outputs, int num_outputs, int top_n); 321 322 void ComputeSecTopN(std::unordered_set<int> *exList, const float *outputs, int num_outputs, 323 int top_n); 324 325 // Adds the computation for the current time-step to the beam. Call at each 326 // time-step in sequence from left to right. outputs is the activation vector 327 // for the current timestep. 
328 void DecodeStep(const float *outputs, int t, double dict_ratio, double cert_offset, 329 double worst_dict_cert, const UNICHARSET *charset, bool debug = false); 330 331 void DecodeSecondaryStep(const float *outputs, int t, double dict_ratio, double cert_offset, 332 double worst_dict_cert, const UNICHARSET *charset, bool debug = false); 333 334 // Saves the most certain choices for the current time-step. 335 void SaveMostCertainChoices(const float *outputs, int num_outputs, const UNICHARSET *charset, 336 int xCoord); 337 338 // Calculates more accurate character boundaries which can be used to 339 // provide more accurate alternative symbol choices. 340 static void calculateCharBoundaries(std::vector<int> *starts, std::vector<int> *ends, 341 std::vector<int> *character_boundaries_, int maxWidth); 342 343 // Adds to the appropriate beams the legal (according to recoder) 344 // continuations of context prev, which is from the given index to beams_, 345 // using the given network outputs to provide scores to the choices. Uses only 346 // those choices for which top_n_flags[code] == top_n_flag. 347 void ContinueContext(const RecodeNode *prev, int index, const float *outputs, 348 TopNState top_n_flag, const UNICHARSET *unicharset, double dict_ratio, 349 double cert_offset, double worst_dict_cert, RecodeBeam *step); 350 // Continues for a new unichar, using dawg or non-dawg as per flag. 351 void ContinueUnichar(int code, int unichar_id, float cert, float worst_dict_cert, 352 float dict_ratio, bool use_dawgs, NodeContinuation cont, 353 const RecodeNode *prev, RecodeBeam *step); 354 // Adds a RecodeNode composed of the args to the correct heap in step if 355 // unichar_id is a valid dictionary continuation of whatever is in prev. 
356 void ContinueDawg(int code, int unichar_id, float cert, NodeContinuation cont, 357 const RecodeNode *prev, RecodeBeam *step); 358 // Sets the correct best_initial_dawgs_ with a RecodeNode composed of the args 359 // if better than what is already there. 360 void PushInitialDawgIfBetter(int code, int unichar_id, PermuterType permuter, bool start, 361 bool end, float cert, NodeContinuation cont, const RecodeNode *prev, 362 RecodeBeam *step); 363 // Adds a RecodeNode composed of the args to the correct heap in step for 364 // partial unichar or duplicate if there is room or if better than the 365 // current worst element if already full. 366 void PushDupOrNoDawgIfBetter(int length, bool dup, int code, int unichar_id, float cert, 367 float worst_dict_cert, float dict_ratio, bool use_dawgs, 368 NodeContinuation cont, const RecodeNode *prev, RecodeBeam *step); 369 // Adds a RecodeNode composed of the args to the correct heap in step if there 370 // is room or if better than the current worst element if already full. 371 void PushHeapIfBetter(int max_size, int code, int unichar_id, PermuterType permuter, 372 bool dawg_start, bool word_start, bool end, bool dup, float cert, 373 const RecodeNode *prev, DawgPositionVector *d, RecodeHeap *heap); 374 // Adds a RecodeNode to heap if there is room 375 // or if better than the current worst element if already full. 376 void PushHeapIfBetter(int max_size, RecodeNode *node, RecodeHeap *heap); 377 // Searches the heap for an entry matching new_node, and updates the entry 378 // with reshuffle if needed. Returns true if there was a match. 379 bool UpdateHeapIfMatched(RecodeNode *new_node, RecodeHeap *heap); 380 // Computes and returns the code-hash for the given code and prev. 381 uint64_t ComputeCodeHash(int code, bool dup, const RecodeNode *prev) const; 382 // Backtracks to extract the best path through the lattice that was built 383 // during Decode. 
On return the best_nodes vector essentially contains the set 384 // of code, score pairs that make the optimal path with the constraint that 385 // the recoder can decode the code sequence back to a sequence of unichar-ids. 386 void ExtractBestPaths(std::vector<const RecodeNode *> *best_nodes, 387 std::vector<const RecodeNode *> *second_nodes) const; 388 // Helper backtracks through the lattice from the given node, storing the 389 // path and reversing it. 390 void ExtractPath(const RecodeNode *node, std::vector<const RecodeNode *> *path) const; 391 void ExtractPath(const RecodeNode *node, std::vector<const RecodeNode *> *path, 392 int limiter) const; 393 // Helper prints debug information on the given lattice path. 394 void DebugPath(const UNICHARSET *unicharset, const std::vector<const RecodeNode *> &path) const; 395 // Helper prints debug information on the given unichar path. 396 void DebugUnicharPath(const UNICHARSET *unicharset, const std::vector<const RecodeNode *> &path, 397 const std::vector<int> &unichar_ids, const std::vector<float> &certs, 398 const std::vector<float> &ratings, const std::vector<int> &xcoords) const; 399 400 static const int kBeamWidths[RecodedCharID::kMaxCodeLen + 1]; 401 402 // The encoder/decoder that we will be using. 403 const UnicharCompress &recoder_; 404 // The beam for each timestep in the output. 405 std::vector<RecodeBeam *> beam_; 406 // Secondary Beam for Results with less Probability 407 std::vector<RecodeBeam *> secondary_beam_; 408 // The number of timesteps valid in beam_; 409 int beam_size_; 410 // A flag to indicate which outputs are the top-n choices. Current timestep 411 // only. 412 std::vector<TopNState> top_n_flags_; 413 // A record of the highest and second scoring codes. 414 int top_code_; 415 int second_code_; 416 // Heap used to compute the top_n_flags_. 417 GenericHeap<TopPair> top_heap_; 418 // Borrowed pointer to the dictionary to use in the search. 
419 Dict *dict_; 420 // True if the language is space-delimited, which is true for most languages 421 // except chi*, jpn, tha. 422 bool space_delimited_; 423 // True if the input is simple text, ie adjacent equal chars are not to be 424 // eliminated. 425 bool is_simple_text_; 426 // The encoded (class label) of the null/reject character. 427 int null_char_; 428 }; 429 430 } // namespace tesseract. 431 432 #endif // THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_ 433