///////////////////////////////////////////////////////////////////////
// File:        lstmrecognizer.h
// Description: Top-level line recognizer class for LSTM-based networks.
// Author:      Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_
#define TESSERACT_LSTM_LSTMRECOGNIZER_H_

#include "ccutil.h"
#include "helpers.h"
#include "matrix.h"
#include "network.h"
#include "networkscratch.h"
#include "params.h"
#include "recodebeam.h"
#include "series.h"
#include "unicharcompress.h"

class BLOB_CHOICE_IT;
struct Pix;
class ROW_RES;
class ScrollView;
class TBOX;
class WERD_RES;

namespace tesseract {

class Dict;
class ImageData;

// Enum indicating training mode control flags.
enum TrainingFlags {
  TF_INT_MODE = 1,
  TF_COMPRESS_UNICHARSET = 64,
};

// Top-level line recognizer class for LSTM-based networks.
// Note that a subclass, LSTMTrainer, is used for training.
class TESS_API LSTMRecognizer {
public:
  LSTMRecognizer();
  LSTMRecognizer(const std::string &language_data_path_prefix);
  ~LSTMRecognizer();

  // Return the number of outputs of the network.
  int NumOutputs() const {
    return network_->NumOutputs();
  }

  // Return the training iterations.
  int training_iteration() const {
    return training_iteration_;
  }

  // Return the sample iterations.
  int sample_iteration() const {
    return sample_iteration_;
  }

  // Return the learning rate.
  float learning_rate() const {
    return learning_rate_;
  }

  // Return the loss type of the network output, or LT_NONE if there is no
  // network.
  LossType OutputLossType() const {
    if (network_ == nullptr) {
      return LT_NONE;
    }
    StaticShape shape;
    shape = network_->OutputShape(shape);
    return shape.loss_type();
  }
  // True if the network output uses the softmax loss type.
  bool SimpleTextOutput() const {
    return OutputLossType() == LT_SOFTMAX;
  }
  // True if the network weights have been converted to int mode.
  bool IsIntMode() const {
    return (training_flags_ & TF_INT_MODE) != 0;
  }
  // True if recoder_ is active to re-encode text to a smaller space.
  bool IsRecoding() const {
    return (training_flags_ & TF_COMPRESS_UNICHARSET) != 0;
  }
  // Returns true if the network is a TensorFlow network.
  bool IsTensorFlow() const {
    return network_->type() == NT_TENSORFLOW;
  }
  // Returns a vector of layer ids that can be passed to other layer functions
  // to access a specific layer.
  std::vector<std::string> EnumerateLayers() const {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    auto *series = static_cast<Series *>(network_);
    std::vector<std::string> layers;
    series->EnumerateLayers(nullptr, layers);
    return layers;
  }
  // Returns a specific layer from its id (from EnumerateLayers).
  Network *GetLayer(const std::string &id) const {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    ASSERT_HOST(id.length() > 1 && id[0] == ':');
    auto *series = static_cast<Series *>(network_);
    return series->GetLayer(&id[1]);
  }
  // Returns the learning rate of the layer from its id.
  float GetLayerLearningRate(const std::string &id) const {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
      ASSERT_HOST(id.length() > 1 && id[0] == ':');
      auto *series = static_cast<Series *>(network_);
      return series->LayerLearningRate(&id[1]);
    } else {
      return learning_rate_;
    }
  }
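
  // Example usage (illustrative sketch, not part of the class): enumerate the
  // layers of a series network and query their effective learning rates.
  // `recognizer` is an assumed, already-loaded LSTMRecognizer instance.
  //
  //   for (const std::string &id : recognizer.EnumerateLayers()) {
  //     float lr = recognizer.GetLayerLearningRate(id);
  //     tprintf("layer %s lr=%g\n", id.c_str(), lr);
  //   }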

  // Return the network string.
  const char *GetNetwork() const {
    return network_str_.c_str();
  }

  // Return the adam beta.
  float GetAdamBeta() const {
    return adam_beta_;
  }

  // Return the momentum.
  float GetMomentum() const {
    return momentum_;
  }

  // Multiplies all the learning rate(s) by the given factor.
  void ScaleLearningRate(double factor) {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    learning_rate_ *= factor;
    if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
      std::vector<std::string> layers = EnumerateLayers();
      for (auto &layer : layers) {
        ScaleLayerLearningRate(layer, factor);
      }
    }
  }
  // Multiplies the learning rate of the layer with the given id by the given
  // factor.
  void ScaleLayerLearningRate(const std::string &id, double factor) {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    ASSERT_HOST(id.length() > 1 && id[0] == ':');
    auto *series = static_cast<Series *>(network_);
    series->ScaleLayerLearningRate(&id[1], factor);
  }

  // Sets all the learning rate(s) to the given value.
  void SetLearningRate(float learning_rate) {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    learning_rate_ = learning_rate;
    if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
      for (auto &id : EnumerateLayers()) {
        SetLayerLearningRate(id, learning_rate);
      }
    }
  }
  // Sets the learning rate of the layer with the given id to the given value.
  void SetLayerLearningRate(const std::string &id, float learning_rate) {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    ASSERT_HOST(id.length() > 1 && id[0] == ':');
    auto *series = static_cast<Series *>(network_);
    series->SetLayerLearningRate(&id[1], learning_rate);
  }
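
  // Example usage (illustrative sketch): during fine-tuning one might halve
  // every learning rate and then pin one layer to a fixed value. The layer id
  // ":0" is hypothetical; real ids come from EnumerateLayers().
  //
  //   recognizer.ScaleLearningRate(0.5);
  //   recognizer.SetLayerLearningRate(":0", 1e-4f);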

  // Converts the network to int if not already.
  void ConvertToInt() {
    if ((training_flags_ & TF_INT_MODE) == 0) {
      network_->ConvertToInt();
      training_flags_ |= TF_INT_MODE;
    }
  }

  // Provides access to the UNICHARSET that this classifier works with.
  const UNICHARSET &GetUnicharset() const {
    return ccutil_.unicharset;
  }
  UNICHARSET &GetUnicharset() {
    return ccutil_.unicharset;
  }
  // Provides access to the UnicharCompress that this classifier works with.
  const UnicharCompress &GetRecoder() const {
    return recoder_;
  }
  // Provides access to the Dict that this classifier works with.
  const Dict *GetDict() const {
    return dict_;
  }
  Dict *GetDict() {
    return dict_;
  }
  // Sets the sample iteration to the given value. The sample_iteration_
  // determines the seed for the random number generator. The training
  // iteration is incremented only by a successful training iteration.
  void SetIteration(int iteration) {
    sample_iteration_ = iteration;
  }
  // Accessors for textline image normalization.
  int NumInputs() const {
    return network_->NumInputs();
  }

  // Return the null char index.
  int null_char() const {
    return null_char_;
  }

  // Loads a model from mgr, including the dictionary only if lang is not empty.
  bool Load(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr);

  // Writes to the given file. Returns false in case of error.
  // If mgr contains a unicharset and recoder, then they are not encoded to fp.
  bool Serialize(const TessdataManager *mgr, TFile *fp) const;
  // Reads from the given file. Returns false in case of error.
  // If mgr contains a unicharset and recoder, then they are taken from there,
  // otherwise, they are part of the serialization in fp.
  bool DeSerialize(const TessdataManager *mgr, TFile *fp);
  // Loads the charsets from mgr.
  bool LoadCharsets(const TessdataManager *mgr);
  // Loads the Recoder.
  bool LoadRecoder(TFile *fp);
  // Loads the dictionary if possible from the traineddata file.
  // If loading fails, prints a warning message and returns false, but
  // otherwise fails silently and continues to work without the dictionary.
  // Note that dictionary load is independent from DeSerialize, but dependent
  // on the unicharset matching. This enables training to deserialize a model
  // from checkpoint or restore without having to go back and reload the
  // dictionary.
  bool LoadDictionary(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr);
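
  // Example usage (illustrative sketch, not part of the API): load a model
  // through a TessdataManager before recognizing lines. The file name
  // "eng.traineddata" and the `params` pointer (normally taken from the owning
  // Tesseract instance) are assumptions; error handling is elided.
  //
  //   TessdataManager mgr;
  //   LSTMRecognizer recognizer;
  //   if (mgr.Init("eng.traineddata") &&
  //       recognizer.Load(params, "eng", &mgr)) {
  //     // The recognizer is now ready for RecognizeLine().
  //   }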

  // Recognizes the line image, contained within image_data, returning the
  // recognized tesseract WERD_RES for the words.
  // If invert, tries inverted as well if the normal interpretation doesn't
  // produce a good enough result. The line_box is used for computing the
  // box_word in the output words. worst_dict_cert is the worst certainty that
  // will be used in a dictionary word.
  void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert,
                     const TBOX &line_box, PointerVector<WERD_RES> *words, int lstm_choice_mode = 0,
                     int lstm_choice_amount = 5);

  // Helper computes min and mean best results in the output.
  void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd);
  // Recognizes the image_data, returning the labels,
  // scores, and corresponding pairs of start, end x-coords in coords.
  // Returned in scale_factor is the reduction factor
  // between the image and the output coords, for computing bounding boxes.
  // If re_invert is true, the input is inverted back to its original
  // photometric interpretation if inversion is attempted but fails to
  // improve the results. This ensures that outputs contains the correct
  // forward outputs for the best photometric interpretation.
  // inputs is filled with the used inputs to the network.
  bool RecognizeLine(const ImageData &image_data, bool invert, bool debug, bool re_invert,
                     bool upside_down, float *scale_factor, NetworkIO *inputs, NetworkIO *outputs);

  // Converts an array of labels to utf-8, whether or not the labels are
  // augmented with character boundaries.
  std::string DecodeLabels(const std::vector<int> &labels);

  // Displays the forward results in a window with the characters and
  // boundaries as determined by the labels and label_coords.
  void DisplayForward(const NetworkIO &inputs, const std::vector<int> &labels,
                      const std::vector<int> &label_coords, const char *window_name,
                      ScrollView **window);
  // Converts the network output to a sequence of labels. Outputs labels, scores
  // and start xcoords of each char, and each null_char_, with an additional
  // final xcoord for the end of the output.
  // The conversion method is determined by internal state.
  void LabelsFromOutputs(const NetworkIO &outputs, std::vector<int> *labels,
                         std::vector<int> *xcoords);
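
  // Example usage (illustrative sketch): run the network forward on a line
  // image and decode the label sequence to UTF-8. `image_data` is assumed to
  // hold a single text line; error handling and inversion retries are elided.
  //
  //   NetworkIO inputs, outputs;
  //   float scale_factor;
  //   if (recognizer.RecognizeLine(image_data, false, false, false, false,
  //                                &scale_factor, &inputs, &outputs)) {
  //     std::vector<int> labels, xcoords;
  //     recognizer.LabelsFromOutputs(outputs, &labels, &xcoords);
  //     std::string text = recognizer.DecodeLabels(labels);
  //   }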

protected:
  // Sets the random seed from the sample_iteration_.
  void SetRandomSeed() {
    int64_t seed = static_cast<int64_t>(sample_iteration_) * 0x10000001;
    randomizer_.set_seed(seed);
    randomizer_.IntRand();
  }

  // Displays the labels and cuts at the corresponding xcoords.
  // Size of labels should match xcoords.
  void DisplayLSTMOutput(const std::vector<int> &labels, const std::vector<int> &xcoords,
                         int height, ScrollView *window);

  // Prints debug output detailing the activation path that is implied by the
  // xcoords.
  void DebugActivationPath(const NetworkIO &outputs, const std::vector<int> &labels,
                           const std::vector<int> &xcoords);

  // Prints debug output detailing activations and 2nd choice over a range
  // of positions.
  void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice,
                            int x_start, int x_end);

  // As LabelsViaCTC except that this function constructs the best path that
  // contains only legal sequences of subcodes for recoder_.
  void LabelsViaReEncode(const NetworkIO &output, std::vector<int> *labels,
                         std::vector<int> *xcoords);
  // Converts the network output to a sequence of labels, with scores, using
  // the simple character model (each position is a char, and the null_char_ is
  // mainly intended for tail padding.)
  void LabelsViaSimpleText(const NetworkIO &output, std::vector<int> *labels,
                           std::vector<int> *xcoords);

  // Returns a string corresponding to the label starting at start. Sets *end
  // to the next start and if non-null, *decoded to the unichar id.
  const char *DecodeLabel(const std::vector<int> &labels, unsigned start, unsigned *end, int *decoded);

  // Returns a string corresponding to a given single label id, falling back to
  // a default of ".." for part of a multi-label unichar-id.
  const char *DecodeSingleLabel(int label);

protected:
  // The network hierarchy.
  Network *network_;
  // The unicharset. Only the unicharset element is serialized.
  // Has to be a CCUtil, so Dict can point to it.
  CCUtil ccutil_;
  // For backward compatibility, recoder_ is serialized iff
  // training_flags_ & TF_COMPRESS_UNICHARSET.
  // Further encode/decode ccutil_.unicharset's ids to simplify the unicharset.
  UnicharCompress recoder_;

  // ==Training parameters that are serialized to provide a record of them.==
  std::string network_str_;
  // Flags used to determine the training method of the network.
  // See enum TrainingFlags above.
  int32_t training_flags_;
  // Number of actual backward training steps used.
  int32_t training_iteration_;
  // Index into training sample set. sample_iteration >= training_iteration_.
  int32_t sample_iteration_;
  // Index in softmax of null character. May take the value UNICHAR_BROKEN or
  // ccutil_.unicharset.size().
  int32_t null_char_;
  // Learning rate and momentum multipliers of deltas in backprop.
  float learning_rate_;
  float momentum_;
  // Smoothing factor for 2nd moment of gradients.
  float adam_beta_;

  // === NOT SERIALIZED.
  TRand randomizer_;
  NetworkScratch scratch_space_;
  // Language model (optional) to use with the beam search.
  Dict *dict_;
  // Beam search held between uses to optimize memory allocation/use.
  RecodeBeamSearch *search_;

  // == Debugging parameters.==
  // Recognition debug display window.
  ScrollView *debug_win_;
};

} // namespace tesseract.

#endif // TESSERACT_LSTM_LSTMRECOGNIZER_H_