///////////////////////////////////////////////////////////////////////
// File:        lstmrecognizer.h
// Description: Top-level line recognizer class for LSTM-based networks.
// Author:      Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////

#ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_
#define TESSERACT_LSTM_LSTMRECOGNIZER_H_

#include "ccutil.h"
#include "helpers.h"
#include "matrix.h"
#include "network.h"
#include "networkscratch.h"
#include "params.h"
#include "recodebeam.h"
#include "series.h"
#include "unicharcompress.h"

class BLOB_CHOICE_IT;
struct Pix;
class ROW_RES;
class ScrollView;
class TBOX;
class WERD_RES;

namespace tesseract {

class Dict;
class ImageData;

// Enum indicating training mode control flags.
enum TrainingFlags {
  TF_INT_MODE = 1,
  TF_COMPRESS_UNICHARSET = 64,
};

// Top-level line recognizer class for LSTM-based networks.
// Note that a sub-class, LSTMTrainer, is used for training.
class TESS_API LSTMRecognizer {
public:
  LSTMRecognizer();
  LSTMRecognizer(const std::string &language_data_path_prefix);
  ~LSTMRecognizer();

  int NumOutputs() const {
    return network_->NumOutputs();
  }

  // Returns the training iteration.
  int training_iteration() const {
    return training_iteration_;
  }

  // Returns the sample iteration.
  int sample_iteration() const {
    return sample_iteration_;
  }

  // Returns the learning rate.
  float learning_rate() const {
    return learning_rate_;
  }

  LossType OutputLossType() const {
    if (network_ == nullptr) {
      return LT_NONE;
    }
    StaticShape shape;
    shape = network_->OutputShape(shape);
    return shape.loss_type();
  }
  bool SimpleTextOutput() const {
    return OutputLossType() == LT_SOFTMAX;
  }
  bool IsIntMode() const {
    return (training_flags_ & TF_INT_MODE) != 0;
  }
  // True if recoder_ is active to re-encode text to a smaller space.
  bool IsRecoding() const {
    return (training_flags_ & TF_COMPRESS_UNICHARSET) != 0;
  }
  // Returns true if the network is a TensorFlow network.
  bool IsTensorFlow() const {
    return network_->type() == NT_TENSORFLOW;
  }
  // Returns a vector of layer ids that can be passed to other layer functions
  // to access a specific layer.
  std::vector<std::string> EnumerateLayers() const {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    auto *series = static_cast<Series *>(network_);
    std::vector<std::string> layers;
    series->EnumerateLayers(nullptr, layers);
    return layers;
  }
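  // Illustrative sketch (not part of the upstream header): the ids returned
  // by EnumerateLayers() are colon-prefixed (the accessors below assert
  // id[0] == ':'), so they can be passed straight to GetLayer() and the
  // layer learning-rate accessors, e.g.:
  //   for (const std::string &id : recognizer.EnumerateLayers()) {
  //     Network *layer = recognizer.GetLayer(id);
  //     float lr = recognizer.GetLayerLearningRate(id);
  //   }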
  // Returns a specific layer from its id (from EnumerateLayers).
  Network *GetLayer(const std::string &id) const {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    ASSERT_HOST(id.length() > 1 && id[0] == ':');
    auto *series = static_cast<Series *>(network_);
    return series->GetLayer(&id[1]);
  }
  // Returns the learning rate of the layer from its id.
  float GetLayerLearningRate(const std::string &id) const {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
      ASSERT_HOST(id.length() > 1 && id[0] == ':');
      auto *series = static_cast<Series *>(network_);
      return series->LayerLearningRate(&id[1]);
    } else {
      return learning_rate_;
    }
  }

  // Returns the network string.
  const char *GetNetwork() const {
    return network_str_.c_str();
  }

  // Returns the adam beta.
  float GetAdamBeta() const {
    return adam_beta_;
  }

  // Returns the momentum.
  float GetMomentum() const {
    return momentum_;
  }

  // Multiplies all the learning rates by the given factor.
  void ScaleLearningRate(double factor) {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    learning_rate_ *= factor;
    if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
      std::vector<std::string> layers = EnumerateLayers();
      for (auto &layer : layers) {
        ScaleLayerLearningRate(layer, factor);
      }
    }
  }
  // Multiplies the learning rate of the layer with the given id by the given
  // factor.
  void ScaleLayerLearningRate(const std::string &id, double factor) {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    ASSERT_HOST(id.length() > 1 && id[0] == ':');
    auto *series = static_cast<Series *>(network_);
    series->ScaleLayerLearningRate(&id[1], factor);
  }

  // Sets all the learning rates to the given value.
  void SetLearningRate(float learning_rate) {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    learning_rate_ = learning_rate;
    if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
      for (auto &id : EnumerateLayers()) {
        SetLayerLearningRate(id, learning_rate);
      }
    }
  }
  // Sets the learning rate of the layer with the given id to the given value.
  void SetLayerLearningRate(const std::string &id, float learning_rate) {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    ASSERT_HOST(id.length() > 1 && id[0] == ':');
    auto *series = static_cast<Series *>(network_);
    series->SetLayerLearningRate(&id[1], learning_rate);
  }
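  // Illustrative sketch (not part of the upstream header): a typical
  // learning-rate decay step during training. When NF_LAYER_SPECIFIC_LR is
  // set, ScaleLearningRate() also scales every per-layer rate, so one call
  // is enough:
  //   recognizer.ScaleLearningRate(0.5);  // halve global and per-layer rates
  //   recognizer.SetLearningRate(1e-4f);  // or pin everything to one value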
  // Converts the network to int if not already.
  void ConvertToInt() {
    if ((training_flags_ & TF_INT_MODE) == 0) {
      network_->ConvertToInt();
      training_flags_ |= TF_INT_MODE;
    }
  }

  // Provides access to the UNICHARSET that this classifier works with.
  const UNICHARSET &GetUnicharset() const {
    return ccutil_.unicharset;
  }
  UNICHARSET &GetUnicharset() {
    return ccutil_.unicharset;
  }
  // Provides access to the UnicharCompress that this classifier works with.
  const UnicharCompress &GetRecoder() const {
    return recoder_;
  }
  // Provides access to the Dict that this classifier works with.
  const Dict *GetDict() const {
    return dict_;
  }
  Dict *GetDict() {
    return dict_;
  }
  // Sets the sample iteration to the given value. The sample_iteration_
  // determines the seed for the random number generator. The training
  // iteration is incremented only by a successful training iteration.
  void SetIteration(int iteration) {
    sample_iteration_ = iteration;
  }
  // Accessors for textline image normalization.
  int NumInputs() const {
    return network_->NumInputs();
  }

  // Returns the null char index.
  int null_char() const {
    return null_char_;
  }

  // Loads a model from mgr, including the dictionary only if lang is not null.
  bool Load(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr);

  // Writes to the given file. Returns false in case of error.
  // If mgr contains a unicharset and recoder, then they are not encoded to fp.
  bool Serialize(const TessdataManager *mgr, TFile *fp) const;
  // Reads from the given file. Returns false in case of error.
  // If mgr contains a unicharset and recoder, then they are taken from there,
  // otherwise, they are part of the serialization in fp.
  bool DeSerialize(const TessdataManager *mgr, TFile *fp);
  // Loads the charsets from mgr.
  bool LoadCharsets(const TessdataManager *mgr);
  // Loads the Recoder.
  bool LoadRecoder(TFile *fp);
  // Loads the dictionary if possible from the traineddata file.
  // If loading fails, prints a warning message and returns false, but
  // otherwise fails silently and continues to work without the dictionary.
  // Note that dictionary load is independent from DeSerialize, but dependent
  // on the unicharset matching. This enables training to deserialize a model
  // from checkpoint or restore without having to go back and reload the
  // dictionary.
  bool LoadDictionary(const ParamsVectors *params, const std::string &lang, TessdataManager *mgr);
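  // Illustrative sketch (not part of the upstream header): loading a model
  // from a traineddata file via TessdataManager. The file name, language and
  // params pointer are placeholders supplied by the caller:
  //   TessdataManager mgr;
  //   LSTMRecognizer recognizer;
  //   if (!mgr.Init("eng.traineddata") ||
  //       !recognizer.Load(params, "eng", &mgr)) {
  //     // handle a missing or incompatible LSTM model
  //   }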

  // Recognizes the line image, contained within image_data, returning the
  // recognized tesseract WERD_RES for the words.
  // If invert, tries inverted as well if the normal interpretation doesn't
  // produce a good enough result. The line_box is used for computing the
  // box_word in the output words. worst_dict_cert is the worst certainty that
  // will be used in a dictionary word.
  void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert,
                     const TBOX &line_box, PointerVector<WERD_RES> *words, int lstm_choice_mode = 0,
                     int lstm_choice_amount = 5);

  // Helper computes min and mean best results in the output.
  void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd);
  // Recognizes the image_data, returning the labels,
  // scores, and corresponding pairs of start, end x-coords in coords.
  // Returned in scale_factor is the reduction factor
  // between the image and the output coords, for computing bounding boxes.
  // If re_invert is true, the input is inverted back to its original
  // photometric interpretation if inversion is attempted but fails to
  // improve the results. This ensures that outputs contains the correct
  // forward outputs for the best photometric interpretation.
  // inputs is filled with the used inputs to the network.
  bool RecognizeLine(const ImageData &image_data, bool invert, bool debug, bool re_invert,
                     bool upside_down, float *scale_factor, NetworkIO *inputs, NetworkIO *outputs);

  // Converts an array of labels to utf-8, whether or not the labels are
  // augmented with character boundaries.
  std::string DecodeLabels(const std::vector<int> &labels);

  // Displays the forward results in a window with the characters and
  // boundaries as determined by the labels and label_coords.
  void DisplayForward(const NetworkIO &inputs, const std::vector<int> &labels,
                      const std::vector<int> &label_coords, const char *window_name,
                      ScrollView **window);
  // Converts the network output to a sequence of labels. Outputs labels, scores
  // and start xcoords of each char, and each null_char_, with an additional
  // final xcoord for the end of the output.
  // The conversion method is determined by internal state.
  void LabelsFromOutputs(const NetworkIO &outputs, std::vector<int> *labels,
                         std::vector<int> *xcoords);
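  // Illustrative sketch (not part of the upstream header): the lower-level
  // recognition path from raw network outputs to a utf-8 string (variable
  // names are placeholders):
  //   NetworkIO inputs, outputs;
  //   float scale_factor;
  //   if (recognizer.RecognizeLine(image_data, false, false, false, false,
  //                                &scale_factor, &inputs, &outputs)) {
  //     std::vector<int> labels, xcoords;
  //     recognizer.LabelsFromOutputs(outputs, &labels, &xcoords);
  //     std::string text = recognizer.DecodeLabels(labels);
  //   }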

protected:
  // Sets the random seed from the sample_iteration_.
  void SetRandomSeed() {
    int64_t seed = static_cast<int64_t>(sample_iteration_) * 0x10000001;
    randomizer_.set_seed(seed);
    randomizer_.IntRand();
  }

  // Displays the labels and cuts at the corresponding xcoords.
  // Size of labels should match xcoords.
  void DisplayLSTMOutput(const std::vector<int> &labels, const std::vector<int> &xcoords,
                         int height, ScrollView *window);

  // Prints debug output detailing the activation path that is implied by the
  // xcoords.
  void DebugActivationPath(const NetworkIO &outputs, const std::vector<int> &labels,
                           const std::vector<int> &xcoords);

  // Prints debug output detailing activations and 2nd choice over a range
  // of positions.
  void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice,
                            int x_start, int x_end);

  // As LabelsViaCTC except that this function constructs the best path that
  // contains only legal sequences of subcodes for recoder_.
  void LabelsViaReEncode(const NetworkIO &output, std::vector<int> *labels,
                         std::vector<int> *xcoords);
  // Converts the network output to a sequence of labels, with scores, using
  // the simple character model (each position is a char, and the null_char_
  // is mainly intended for tail padding).
  void LabelsViaSimpleText(const NetworkIO &output, std::vector<int> *labels,
                           std::vector<int> *xcoords);

  // Returns a string corresponding to the label starting at start. Sets *end
  // to the next start and if non-null, *decoded to the unichar id.
  const char *DecodeLabel(const std::vector<int> &labels, unsigned start, unsigned *end,
                          int *decoded);

  // Returns a string corresponding to a given single label id, falling back to
  // a default of ".." for part of a multi-label unichar-id.
  const char *DecodeSingleLabel(int label);

protected:
  // The network hierarchy.
  Network *network_;
  // The unicharset. Only the unicharset element is serialized.
  // Has to be a CCUtil, so Dict can point to it.
  CCUtil ccutil_;
  // For backward compatibility, recoder_ is serialized iff
  // training_flags_ & TF_COMPRESS_UNICHARSET.
  // Further encodes/decodes ccutil_.unicharset's ids to simplify the
  // unicharset.
  UnicharCompress recoder_;

  // == Training parameters that are serialized to provide a record of them. ==
  std::string network_str_;
  // Flags used to determine the training method of the network.
  // See enum TrainingFlags above.
  int32_t training_flags_;
  // Number of actual backward training steps used.
  int32_t training_iteration_;
  // Index into training sample set. sample_iteration_ >= training_iteration_.
  int32_t sample_iteration_;
  // Index in softmax of null character. May take the value UNICHAR_BROKEN or
  // ccutil_.unicharset.size().
  int32_t null_char_;
  // Learning rate and momentum multipliers of deltas in backprop.
  float learning_rate_;
  float momentum_;
  // Smoothing factor for 2nd moment of gradients.
  float adam_beta_;

  // == NOT SERIALIZED. ==
  TRand randomizer_;
  NetworkScratch scratch_space_;
  // Language model (optional) to use with the beam search.
  Dict *dict_;
  // Beam search held between uses to optimize memory allocation/use.
  RecodeBeamSearch *search_;

  // == Debugging parameters. ==
  // Recognition debug display window.
  ScrollView *debug_win_;
};

} // namespace tesseract

#endif // TESSERACT_LSTM_LSTMRECOGNIZER_H_