1 ///////////////////////////////////////////////////////////////////////
2 // File:        recodebeam.h
3 // Description: Beam search to decode from the re-encoded CJK as a sequence of
4 //              smaller numbers in place of a single large code.
5 // Author:      Ray Smith
6 //
7 // (C) Copyright 2015, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
17 //
18 ///////////////////////////////////////////////////////////////////////
19 
20 #ifndef THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_
21 #define THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_
22 
23 #include "dawg.h"
24 #include "dict.h"
25 #include "genericheap.h"
26 #include "genericvector.h"
27 #include "kdpair.h"
28 #include "networkio.h"
29 #include "ratngs.h"
30 #include "unicharcompress.h"
31 
32 #include <deque>
33 #include <set>
34 #include <tuple>
35 #include <unordered_set>
36 #include <vector>
37 
38 namespace tesseract {
39 
40 // Enum describing what can follow the current node.
41 // Consider the following softmax outputs:
42 // Timestep    0    1    2    3    4    5    6    7    8
43 // X-score    0.01 0.55 0.98 0.42 0.01 0.01 0.40 0.95 0.01
44 // Y-score    0.00 0.01 0.01 0.01 0.01 0.97 0.59 0.04 0.01
45 // Null-score 0.99 0.44 0.01 0.57 0.98 0.02 0.01 0.01 0.98
46 // Then the correct CTC decoding (in which adjacent equal classes are folded,
47 // and then all nulls are dropped) is clearly XYX, but simple decoding (taking
48 // the max at each timestep) leads to:
49 // Null@0.99 X@0.55 X@0.98 Null@0.57 Null@0.98 Y@0.97 Y@0.59 X@0.95 Null@0.98,
50 // which folds to the correct XYX. The conversion to Tesseract rating and
51 // certainty uses the sum of the log probs (log of the product of probabilities)
52 // for the Rating and the minimum log prob for the certainty, but that yields a
53 // minimum certainty of log(0.55), which is poor for such an obvious case.
54 // CTC says that the probability of the result is the SUM of the products of the
55 // probabilities over ALL PATHS that decode to the same result, which includes:
// NXXNNYYXN, NNXNNYYXN, NXXXNYYXN, NNXXNYXXN, and others including XXXXXYYXX.
57 // That is intractable, so some compromise between simple and ideal is needed.
58 // Observing that evenly split timesteps rarely happen next to each other, we
59 // allow scores at a transition between classes to be added for decoding thus:
60 // N@0.99 (N+X)@0.99 X@0.98 (N+X)@0.99 N@0.98 Y@0.97 (X+Y+N)@1.00 X@0.95 N@0.98.
61 // This works because NNX and NXX both decode to X, so in the middle we can use
62 // N+X. Note that the classes either side of a sum must stand alone, i.e. use a
63 // single score, to force all paths to pass through them and decode to the same
64 // result. Also in the special case of a transition from X to Y, with only one
65 // timestep between, it is possible to add X+Y+N, since XXY, XYY, and XNY all
66 // decode to XY.
67 // An important condition is that we cannot combine X and Null between two
68 // stand-alone Xs, since that can decode as XNX->XX or XXX->X, so the scores for
69 // X and Null have to go in separate paths. Combining scores in this way
70 // provides a much better minimum certainty of log(0.95).
71 // In the implementation of the beam search, we have to place the possibilities
72 // X, X+N and X+Y+N in the beam under appropriate conditions of the previous
73 // node, and constrain what can follow, to enforce the rules explained above.
74 // We therefore have 3 different types of node determined by what can follow:
enum NodeContinuation {
  // The node's score is its own class score alone, so any class may follow.
  NC_ANYTHING = 0,
  // The node's score is a sum of its own class score and another score, and
  // no stand-alone duplicate came before it, so the next node has to be a
  // stand-alone duplicate of this one.
  NC_ONLY_DUP = 1,
  // The node's score is a sum of its own class score and another score, and
  // it comes after a stand-alone, so the next node may be anything except a
  // duplicate of this one.
  NC_NO_DUP = 2,
  // Number of continuation types. Sizes the per-continuation arrays and the
  // beam index arithmetic in RecodeBeamSearch.
  NC_COUNT = 3
};
85 
// Classifies a code by how its score ranks among the outputs of a timestep.
enum TopNState {
  // Ranked first or second.
  TN_TOP2 = 0,
  // Inside the top-n, but below the first two.
  TN_TOPN = 1,
  // Outside the top-n.
  TN_ALSO_RAN = 2,
  // Number of states listed above.
  TN_COUNT = 3
};
93 
94 // Lattice element for Re-encode beam search.
95 struct RecodeNode {
RecodeNodeRecodeNode96   RecodeNode()
97       : code(-1)
98       , unichar_id(INVALID_UNICHAR_ID)
99       , permuter(TOP_CHOICE_PERM)
100       , start_of_dawg(false)
101       , start_of_word(false)
102       , end_of_word(false)
103       , duplicate(false)
104       , certainty(0.0f)
105       , score(0.0f)
106       , prev(nullptr)
107       , dawgs(nullptr)
108       , code_hash(0) {}
RecodeNodeRecodeNode109   RecodeNode(int c, int uni_id, PermuterType perm, bool dawg_start, bool word_start, bool end,
110              bool dup, float cert, float s, const RecodeNode *p, DawgPositionVector *d,
111              uint64_t hash)
112       : code(c)
113       , unichar_id(uni_id)
114       , permuter(perm)
115       , start_of_dawg(dawg_start)
116       , start_of_word(word_start)
117       , end_of_word(end)
118       , duplicate(dup)
119       , certainty(cert)
120       , score(s)
121       , prev(p)
122       , dawgs(d)
123       , code_hash(hash) {}
124   // NOTE: If we could use C++11, then this would be a move constructor.
125   // Instead we have copy constructor that does a move!! This is because we
126   // don't want to copy the whole DawgPositionVector each time, and true
127   // copying isn't necessary for this struct. It does get moved around a lot
128   // though inside the heap and during heap push, hence the move semantics.
RecodeNodeRecodeNode129   RecodeNode(const RecodeNode &src) : dawgs(nullptr) {
130     *this = src;
131     ASSERT_HOST(src.dawgs == nullptr);
132   }
133   RecodeNode &operator=(const RecodeNode &src) {
134     delete dawgs;
135     memcpy(this, &src, sizeof(src));
136     ((RecodeNode &)src).dawgs = nullptr;
137     return *this;
138   }
~RecodeNodeRecodeNode139   ~RecodeNode() {
140     delete dawgs;
141   }
142   // Prints details of the node.
143   void Print(int null_char, const UNICHARSET &unicharset, int depth) const;
144 
145   // The re-encoded code here = index to network output.
146   int code;
147   // The decoded unichar_id is only valid for the final code of a sequence.
148   int unichar_id;
149   // The type of permuter active at this point. Intervals between start_of_word
150   // and end_of_word make valid words of type given by permuter where
151   // end_of_word is true. These aren't necessarily delimited by spaces.
152   PermuterType permuter;
153   // True if this is the initial dawg state. May be attached to a space or,
154   // in a non-space-delimited lang, the end of the previous word.
155   bool start_of_dawg;
156   // True if this is the first node in a dictionary word.
157   bool start_of_word;
158   // True if this represents a valid candidate end of word position. Does not
159   // necessarily mark the end of a word, since a word can be extended beyond a
160   // candidate end by a continuation, eg 'the' continues to 'these'.
161   bool end_of_word;
162   // True if this->code is a duplicate of prev->code. Some training modes
163   // allow the network to output duplicate characters and crush them with CTC,
164   // but that would mess up the dictionary search, so we just smash them
165   // together on the fly using the duplicate flag.
166   bool duplicate;
167   // Certainty (log prob) of (just) this position.
168   float certainty;
169   // Total certainty of the path to this position.
170   float score;
171   // The previous node in this chain. Borrowed pointer.
172   const RecodeNode *prev;
173   // The currently active dawgs at this position. Owned pointer.
174   DawgPositionVector *dawgs;
175   // A hash of all codes in the prefix and this->code as well. Used for
176   // duplicate path removal.
177   uint64_t code_hash;
178 };
179 
// Heap entry for the beam search: a RecodeNode keyed by a double.
// NOTE(review): the key presumably carries the node's score so that the heap
// can order candidates - confirm against the push code in recodebeam.cpp.
using RecodePair = KDPairInc<double, RecodeNode>;
// Heap of RecodePair entries. This is the element type of
// RecodeBeamSearch::RecodeBeam::beams_.
using RecodeHeap = GenericHeap<RecodePair>;
182 
183 // Class that holds the entire beam search for recognition of a text line.
184 class TESS_API RecodeBeamSearch {
185 public:
186   // Borrows the pointer, which is expected to survive until *this is deleted.
187   RecodeBeamSearch(const UnicharCompress &recoder, int null_char, bool simple_text, Dict *dict);
188   ~RecodeBeamSearch();
189 
190   // Decodes the set of network outputs, storing the lattice internally.
191   // If charset is not null, it enables detailed debugging of the beam search.
192   void Decode(const NetworkIO &output, double dict_ratio, double cert_offset,
193               double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode = 0);
194   void Decode(const GENERIC_2D_ARRAY<float> &output, double dict_ratio, double cert_offset,
195               double worst_dict_cert, const UNICHARSET *charset);
196 
197   void DecodeSecondaryBeams(const NetworkIO &output, double dict_ratio, double cert_offset,
198                             double worst_dict_cert, const UNICHARSET *charset,
199                             int lstm_choice_mode = 0);
200 
201   // Returns the best path as labels/scores/xcoords similar to simple CTC.
202   void ExtractBestPathAsLabels(std::vector<int> *labels, std::vector<int> *xcoords) const;
203   // Returns the best path as unichar-ids/certs/ratings/xcoords skipping
204   // duplicates, nulls and intermediate parts.
205   void ExtractBestPathAsUnicharIds(bool debug, const UNICHARSET *unicharset,
206                                    std::vector<int> *unichar_ids, std::vector<float> *certs,
207                                    std::vector<float> *ratings, std::vector<int> *xcoords) const;
208 
209   // Returns the best path as a set of WERD_RES.
210   void ExtractBestPathAsWords(const TBOX &line_box, float scale_factor, bool debug,
211                               const UNICHARSET *unicharset, PointerVector<WERD_RES> *words,
212                               int lstm_choice_mode = 0);
213 
214   // Generates debug output of the content of the beams after a Decode.
215   void DebugBeams(const UNICHARSET &unicharset) const;
216 
217   // Extract the best charakters from the current decode iteration and block
218   // those symbols for the next iteration. In contrast to tesseracts standard
219   // method to chose the best overall node chain, this methods looks at a short
220   // node chain segmented by the character boundaries and chooses the best
221   // option independent of the remaining node chain.
222   void extractSymbolChoices(const UNICHARSET *unicharset);
223 
224   // Generates debug output of the content of the beams after a Decode.
225   void PrintBeam2(bool uids, int num_outputs, const UNICHARSET *charset, bool secondary) const;
226   // Segments the timestep bundle by the character_boundaries.
227   void segmentTimestepsByCharacters();
228   std::vector<std::vector<std::pair<const char *, float>>>
229   // Unions the segmented timestep character bundles to one big bundle.
230   combineSegmentedTimesteps(
231       std::vector<std::vector<std::vector<std::pair<const char *, float>>>> *segmentedTimesteps);
232   // Stores the alternative characters of every timestep together with their
233   // probability.
234   std::vector<std::vector<std::pair<const char *, float>>> timesteps;
235   std::vector<std::vector<std::vector<std::pair<const char *, float>>>> segmentedTimesteps;
236   // Stores the character choices found in the ctc algorithm
237   std::vector<std::vector<std::pair<const char *, float>>> ctc_choices;
238   // Stores all unicharids which are excluded for future iterations
239   std::vector<std::unordered_set<int>> excludedUnichars;
240   // Stores the character boundaries regarding timesteps.
241   std::vector<int> character_boundaries_;
242   // Clipping value for certainty inside Tesseract. Reflects the minimum value
243   // of certainty that will be returned by ExtractBestPathAsUnicharIds.
244   // Supposedly on a uniform scale that can be compared across languages and
245   // engines.
246   static constexpr float kMinCertainty = -20.0f;
247   // Number of different code lengths for which we have a separate beam.
248   static const int kNumLengths = RecodedCharID::kMaxCodeLen + 1;
249   // Total number of beams: dawg/nodawg * number of NodeContinuation * number
250   // of different lengths.
251   static const int kNumBeams = 2 * NC_COUNT * kNumLengths;
252   // Returns the relevant factor in the beams_ index.
LengthFromBeamsIndex(int index)253   static int LengthFromBeamsIndex(int index) {
254     return index % kNumLengths;
255   }
ContinuationFromBeamsIndex(int index)256   static NodeContinuation ContinuationFromBeamsIndex(int index) {
257     return static_cast<NodeContinuation>((index / kNumLengths) % NC_COUNT);
258   }
IsDawgFromBeamsIndex(int index)259   static bool IsDawgFromBeamsIndex(int index) {
260     return index / (kNumLengths * NC_COUNT) > 0;
261   }
262   // Computes a beams_ index from the given factors.
BeamIndex(bool is_dawg,NodeContinuation cont,int length)263   static int BeamIndex(bool is_dawg, NodeContinuation cont, int length) {
264     return (is_dawg * NC_COUNT + cont) * kNumLengths + length;
265   }
266 
267 private:
268   // Struct for the Re-encode beam search. This struct holds the data for
269   // a single time-step position of the output. Use a vector<RecodeBeam>
270   // to hold all the timesteps and prevent reallocation of the individual heaps.
271   struct RecodeBeam {
272     // Resets to the initial state without deleting all the memory.
ClearRecodeBeam273     void Clear() {
274       for (auto &beam : beams_) {
275         beam.clear();
276       }
277       RecodeNode empty;
278       for (auto &best_initial_dawg : best_initial_dawgs_) {
279         best_initial_dawg = empty;
280       }
281     }
282 
283     // A separate beam for each combination of code length,
284     // NodeContinuation, and dictionary flag. Separating out all these types
285     // allows the beam to be quite narrow, and yet still have a low chance of
286     // losing the best path.
287     // We have to keep all these beams separate, since the highest scoring paths
288     // come from the paths that are most likely to dead-end at any time, like
289     // dawg paths, NC_ONLY_DUP etc.
290     // Each heap is stored with the WORST result at the top, so we can quickly
291     // get the top-n values.
292     RecodeHeap beams_[kNumBeams];
293     // While the language model is only a single word dictionary, we can use
294     // word starts as a choke point in the beam, and keep only a single dict
295     // start node at each step (for each NodeContinuation type), so we find the
296     // best one here and push it on the heap, if it qualifies, after processing
297     // all of the step.
298     RecodeNode best_initial_dawgs_[NC_COUNT];
299   };
300   using TopPair = KDPairInc<float, int>;
301 
302   // Generates debug output of the content of a single beam position.
303   void DebugBeamPos(const UNICHARSET &unicharset, const RecodeHeap &heap) const;
304 
305   // Returns the given best_nodes as unichar-ids/certs/ratings/xcoords skipping
306   // duplicates, nulls and intermediate parts.
307   static void ExtractPathAsUnicharIds(const std::vector<const RecodeNode *> &best_nodes,
308                                       std::vector<int> *unichar_ids, std::vector<float> *certs,
309                                       std::vector<float> *ratings, std::vector<int> *xcoords,
310                                       std::vector<int> *character_boundaries = nullptr);
311 
312   // Sets up a word with the ratings matrix and fake blobs with boxes in the
313   // right places.
314   WERD_RES *InitializeWord(bool leading_space, const TBOX &line_box, int word_start, int word_end,
315                            float space_certainty, const UNICHARSET *unicharset,
316                            const std::vector<int> &xcoords, float scale_factor);
317 
318   // Fills top_n_flags_ with bools that are true iff the corresponding output
319   // is one of the top_n.
320   void ComputeTopN(const float *outputs, int num_outputs, int top_n);
321 
322   void ComputeSecTopN(std::unordered_set<int> *exList, const float *outputs, int num_outputs,
323                       int top_n);
324 
325   // Adds the computation for the current time-step to the beam. Call at each
326   // time-step in sequence from left to right. outputs is the activation vector
327   // for the current timestep.
328   void DecodeStep(const float *outputs, int t, double dict_ratio, double cert_offset,
329                   double worst_dict_cert, const UNICHARSET *charset, bool debug = false);
330 
331   void DecodeSecondaryStep(const float *outputs, int t, double dict_ratio, double cert_offset,
332                            double worst_dict_cert, const UNICHARSET *charset, bool debug = false);
333 
334   // Saves the most certain choices for the current time-step.
335   void SaveMostCertainChoices(const float *outputs, int num_outputs, const UNICHARSET *charset,
336                               int xCoord);
337 
338   // Calculates more accurate character boundaries which can be used to
339   // provide more accurate alternative symbol choices.
340   static void calculateCharBoundaries(std::vector<int> *starts, std::vector<int> *ends,
341                                       std::vector<int> *character_boundaries_, int maxWidth);
342 
343   // Adds to the appropriate beams the legal (according to recoder)
344   // continuations of context prev, which is from the given index to beams_,
345   // using the given network outputs to provide scores to the choices. Uses only
346   // those choices for which top_n_flags[code] == top_n_flag.
347   void ContinueContext(const RecodeNode *prev, int index, const float *outputs,
348                        TopNState top_n_flag, const UNICHARSET *unicharset, double dict_ratio,
349                        double cert_offset, double worst_dict_cert, RecodeBeam *step);
350   // Continues for a new unichar, using dawg or non-dawg as per flag.
351   void ContinueUnichar(int code, int unichar_id, float cert, float worst_dict_cert,
352                        float dict_ratio, bool use_dawgs, NodeContinuation cont,
353                        const RecodeNode *prev, RecodeBeam *step);
354   // Adds a RecodeNode composed of the args to the correct heap in step if
355   // unichar_id is a valid dictionary continuation of whatever is in prev.
356   void ContinueDawg(int code, int unichar_id, float cert, NodeContinuation cont,
357                     const RecodeNode *prev, RecodeBeam *step);
358   // Sets the correct best_initial_dawgs_ with a RecodeNode composed of the args
359   // if better than what is already there.
360   void PushInitialDawgIfBetter(int code, int unichar_id, PermuterType permuter, bool start,
361                                bool end, float cert, NodeContinuation cont, const RecodeNode *prev,
362                                RecodeBeam *step);
363   // Adds a RecodeNode composed of the args to the correct heap in step for
364   // partial unichar or duplicate if there is room or if better than the
365   // current worst element if already full.
366   void PushDupOrNoDawgIfBetter(int length, bool dup, int code, int unichar_id, float cert,
367                                float worst_dict_cert, float dict_ratio, bool use_dawgs,
368                                NodeContinuation cont, const RecodeNode *prev, RecodeBeam *step);
369   // Adds a RecodeNode composed of the args to the correct heap in step if there
370   // is room or if better than the current worst element if already full.
371   void PushHeapIfBetter(int max_size, int code, int unichar_id, PermuterType permuter,
372                         bool dawg_start, bool word_start, bool end, bool dup, float cert,
373                         const RecodeNode *prev, DawgPositionVector *d, RecodeHeap *heap);
374   // Adds a RecodeNode to heap if there is room
375   // or if better than the current worst element if already full.
376   void PushHeapIfBetter(int max_size, RecodeNode *node, RecodeHeap *heap);
377   // Searches the heap for an entry matching new_node, and updates the entry
378   // with reshuffle if needed. Returns true if there was a match.
379   bool UpdateHeapIfMatched(RecodeNode *new_node, RecodeHeap *heap);
380   // Computes and returns the code-hash for the given code and prev.
381   uint64_t ComputeCodeHash(int code, bool dup, const RecodeNode *prev) const;
382   // Backtracks to extract the best path through the lattice that was built
383   // during Decode. On return the best_nodes vector essentially contains the set
384   // of code, score pairs that make the optimal path with the constraint that
385   // the recoder can decode the code sequence back to a sequence of unichar-ids.
386   void ExtractBestPaths(std::vector<const RecodeNode *> *best_nodes,
387                         std::vector<const RecodeNode *> *second_nodes) const;
388   // Helper backtracks through the lattice from the given node, storing the
389   // path and reversing it.
390   void ExtractPath(const RecodeNode *node, std::vector<const RecodeNode *> *path) const;
391   void ExtractPath(const RecodeNode *node, std::vector<const RecodeNode *> *path,
392                    int limiter) const;
393   // Helper prints debug information on the given lattice path.
394   void DebugPath(const UNICHARSET *unicharset, const std::vector<const RecodeNode *> &path) const;
395   // Helper prints debug information on the given unichar path.
396   void DebugUnicharPath(const UNICHARSET *unicharset, const std::vector<const RecodeNode *> &path,
397                         const std::vector<int> &unichar_ids, const std::vector<float> &certs,
398                         const std::vector<float> &ratings, const std::vector<int> &xcoords) const;
399 
400   static const int kBeamWidths[RecodedCharID::kMaxCodeLen + 1];
401 
402   // The encoder/decoder that we will be using.
403   const UnicharCompress &recoder_;
404   // The beam for each timestep in the output.
405   std::vector<RecodeBeam *> beam_;
406   // Secondary Beam for Results with less Probability
407   std::vector<RecodeBeam *> secondary_beam_;
408   // The number of timesteps valid in beam_;
409   int beam_size_;
410   // A flag to indicate which outputs are the top-n choices. Current timestep
411   // only.
412   std::vector<TopNState> top_n_flags_;
413   // A record of the highest and second scoring codes.
414   int top_code_;
415   int second_code_;
416   // Heap used to compute the top_n_flags_.
417   GenericHeap<TopPair> top_heap_;
418   // Borrowed pointer to the dictionary to use in the search.
419   Dict *dict_;
420   // True if the language is space-delimited, which is true for most languages
421   // except chi*, jpn, tha.
422   bool space_delimited_;
423   // True if the input is simple text, ie adjacent equal chars are not to be
424   // eliminated.
425   bool is_simple_text_;
426   // The encoded (class label) of the null/reject character.
427   int null_char_;
428 };
429 
430 } // namespace tesseract.
431 
432 #endif // THIRD_PARTY_TESSERACT_LSTM_RECODEBEAM_H_
433