1 // Copyright 2010 Google Inc. All Rights Reserved. 2 // Author: rays@google.com (Ray Smith) 3 /////////////////////////////////////////////////////////////////////// 4 // File: mastertrainer.h 5 // Description: Trainer to build the MasterClassifier. 6 // Author: Ray Smith 7 // 8 // (C) Copyright 2010, Google Inc. 9 // Licensed under the Apache License, Version 2.0 (the "License"); 10 // you may not use this file except in compliance with the License. 11 // You may obtain a copy of the License at 12 // http://www.apache.org/licenses/LICENSE-2.0 13 // Unless required by applicable law or agreed to in writing, software 14 // distributed under the License is distributed on an "AS IS" BASIS, 15 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 // See the License for the specific language governing permissions and 17 // limitations under the License. 18 // 19 /////////////////////////////////////////////////////////////////////// 20 21 #ifndef TESSERACT_TRAINING_MASTERTRAINER_H_ 22 #define TESSERACT_TRAINING_MASTERTRAINER_H_ 23 24 #include "export.h" 25 26 #include "classify.h" 27 #include "cluster.h" 28 #include "elst.h" 29 #include "errorcounter.h" 30 #include "featdefs.h" 31 #include "fontinfo.h" 32 #include "indexmapbidi.h" 33 #include "intfeaturemap.h" 34 #include "intfeaturespace.h" 35 #include "intfx.h" 36 #include "intmatcher.h" 37 #include "params.h" 38 #include "shapetable.h" 39 #include "trainingsample.h" 40 #include "trainingsampleset.h" 41 #include "unicharset.h" 42 43 namespace tesseract { 44 45 class ShapeClassifier; 46 47 // Simple struct to hold the distance between two shapes during clustering. 48 struct ShapeDist { ShapeDistShapeDist49 ShapeDist() : shape1(0), shape2(0), distance(0.0f) {} ShapeDistShapeDist50 ShapeDist(int s1, int s2, float dist) : shape1(s1), shape2(s2), distance(dist) {} 51 52 // Sort operator to sort in ascending order of distance. 53 bool operator<(const ShapeDist &other) const { 54 return distance < other.distance; 55 } 56 57 int shape1; 58 int shape2; 59 float distance; 60 }; 61 62 // Class to encapsulate training processes that use the TrainingSampleSet. 63 // Initially supports shape clustering and mftrainining. 64 // Other important features of the MasterTrainer are conditioning the data 65 // by outlier elimination, replication with perturbation, and serialization. 66 class TESS_COMMON_TRAINING_API MasterTrainer { 67 public: 68 MasterTrainer(NormalizationMode norm_mode, bool shape_analysis, bool replicate_samples, 69 int debug_level); 70 ~MasterTrainer(); 71 72 // Writes to the given file. Returns false in case of error. 73 bool Serialize(FILE *fp) const; 74 75 // Loads an initial unicharset, or sets one up if the file cannot be read. 76 void LoadUnicharset(const char *filename); 77 78 // Sets the feature space definition. SetFeatureSpace(const IntFeatureSpace & fs)79 void SetFeatureSpace(const IntFeatureSpace &fs) { 80 feature_space_ = fs; 81 feature_map_.Init(fs); 82 } 83 84 // Reads the samples and their features from the given file, 85 // adding them to the trainer with the font_id from the content of the file. 86 // If verification, then these are verification samples, not training. 87 void ReadTrainingSamples(const char *page_name, const FEATURE_DEFS_STRUCT &feature_defs, 88 bool verification); 89 90 // Adds the given single sample to the trainer, setting the classid 91 // appropriately from the given unichar_str. 92 void AddSample(bool verification, const char *unichar_str, TrainingSample *sample); 93 94 // Loads all pages from the given tif filename and append to page_images_. 95 // Must be called after ReadTrainingSamples, as the current number of images 96 // is used as an offset for page numbers in the samples. 97 void LoadPageImages(const char *filename); 98 99 // Cleans up the samples after initial load from the tr files, and prior to 100 // saving the MasterTrainer: 101 // Remaps fragmented chars if running shape analysis. 102 // Sets up the samples appropriately for class/fontwise access. 103 // Deletes outlier samples. 104 void PostLoadCleanup(); 105 106 // Gets the samples ready for training. Use after both 107 // ReadTrainingSamples+PostLoadCleanup or DeSerialize. 108 // Re-indexes the features and computes canonical and cloud features. 109 void PreTrainingSetup(); 110 111 // Sets up the master_shapes_ table, which tells which fonts should stay 112 // together until they get to a leaf node classifier. 113 void SetupMasterShapes(); 114 115 // Adds the junk_samples_ to the main samples_ set. Junk samples are initially 116 // fragments and n-grams (all incorrectly segmented characters). 117 // Various training functions may result in incorrectly segmented characters 118 // being added to the unicharset of the main samples, perhaps because they 119 // form a "radical" decomposition of some (Indic) grapheme, or because they 120 // just look the same as a real character (like rn/m) 121 // This function moves all the junk samples, to the main samples_ set, but 122 // desirable junk, being any sample for which the unichar already exists in 123 // the samples_ unicharset gets the unichar-ids re-indexed to match, but 124 // anything else gets re-marked as unichar_id 0 (space character) to identify 125 // it as junk to the error counter. 126 void IncludeJunk(); 127 128 // Replicates the samples and perturbs them if the enable_replication_ flag 129 // is set. MUST be used after the last call to OrganizeByFontAndClass on 130 // the training samples, ie after IncludeJunk if it is going to be used, as 131 // OrganizeByFontAndClass will eat the replicated samples into the regular 132 // samples. 133 void ReplicateAndRandomizeSamplesIfRequired(); 134 135 // Loads the basic font properties file into fontinfo_table_. 136 // Returns false on failure. 137 bool LoadFontInfo(const char *filename); 138 139 // Loads the xheight font properties file into xheights_. 140 // Returns false on failure. 141 bool LoadXHeights(const char *filename); 142 143 // Reads spacing stats from filename and adds them to fontinfo_table. 144 // Returns false on failure. 145 bool AddSpacingInfo(const char *filename); 146 147 // Returns the font id corresponding to the given font name. 148 // Returns -1 if the font cannot be found. 149 int GetFontInfoId(const char *font_name); 150 // Returns the font_id of the closest matching font name to the given 151 // filename. It is assumed that a substring of the filename will match 152 // one of the fonts. If more than one is matched, the longest is returned. 153 int GetBestMatchingFontInfoId(const char *filename); 154 155 // Returns the filename of the tr file corresponding to the command-line 156 // argument with the given index. GetTRFileName(int index)157 const std::string &GetTRFileName(int index) const { 158 return tr_filenames_[index]; 159 } 160 161 // Sets up a flat shapetable with one shape per class/font combination. 162 void SetupFlatShapeTable(ShapeTable *shape_table); 163 164 // Sets up a Clusterer for mftraining on a single shape_id. 165 // Call FreeClusterer on the return value after use. 166 CLUSTERER *SetupForClustering(const ShapeTable &shape_table, 167 const FEATURE_DEFS_STRUCT &feature_defs, int shape_id, 168 int *num_samples); 169 170 // Writes the given float_classes (produced by SetupForFloat2Int) as inttemp 171 // to the given inttemp_file, and the corresponding pffmtable. 172 // The unicharset is the original encoding of graphemes, and shape_set should 173 // match the size of the shape_table, and may possibly be totally fake. 174 void WriteInttempAndPFFMTable(const UNICHARSET &unicharset, const UNICHARSET &shape_set, 175 const ShapeTable &shape_table, CLASS_STRUCT *float_classes, 176 const char *inttemp_file, const char *pffmtable_file); 177 unicharset()178 const UNICHARSET &unicharset() const { 179 return samples_.unicharset(); 180 } GetSamples()181 TrainingSampleSet *GetSamples() { 182 return &samples_; 183 } master_shapes()184 const ShapeTable &master_shapes() const { 185 return master_shapes_; 186 } 187 188 // Generates debug output relating to the canonical distance between the 189 // two given UTF8 grapheme strings. 190 void DebugCanonical(const char *unichar_str1, const char *unichar_str2); 191 #ifndef GRAPHICS_DISABLED 192 // Debugging for cloud/canonical features. 193 // Displays a Features window containing: 194 // If unichar_str2 is in the unicharset, and canonical_font is non-negative, 195 // displays the canonical features of the char/font combination in red. 196 // If unichar_str1 is in the unicharset, and cloud_font is non-negative, 197 // displays the cloud feature of the char/font combination in green. 198 // The canonical features are drawn first to show which ones have no 199 // matches in the cloud features. 200 // Until the features window is destroyed, each click in the features window 201 // will display the samples that have that feature in a separate window. 202 void DisplaySamples(const char *unichar_str1, int cloud_font, const char *unichar_str2, 203 int canonical_font); 204 #endif // !GRAPHICS_DISABLED 205 206 void TestClassifierVOld(bool replicate_samples, ShapeClassifier *test_classifier, 207 ShapeClassifier *old_classifier); 208 209 // Tests the given test_classifier on the internal samples. 210 // See TestClassifier for details. 211 void TestClassifierOnSamples(CountTypes error_mode, int report_level, bool replicate_samples, 212 ShapeClassifier *test_classifier, std::string *report_string); 213 // Tests the given test_classifier on the given samples 214 // error_mode indicates what counts as an error. 215 // report_levels: 216 // 0 = no output. 217 // 1 = bottom-line error rate. 218 // 2 = bottom-line error rate + time. 219 // 3 = font-level error rate + time. 220 // 4 = list of all errors + short classifier debug output on 16 errors. 221 // 5 = list of all errors + short classifier debug output on 25 errors. 222 // If replicate_samples is true, then the test is run on an extended test 223 // sample including replicated and systematically perturbed samples. 224 // If report_string is non-nullptr, a summary of the results for each font 225 // is appended to the report_string. 226 double TestClassifier(CountTypes error_mode, int report_level, bool replicate_samples, 227 TrainingSampleSet *samples, ShapeClassifier *test_classifier, 228 std::string *report_string); 229 230 // Returns the average (in some sense) distance between the two given 231 // shapes, which may contain multiple fonts and/or unichars. 232 // This function is public to facilitate testing. 233 float ShapeDistance(const ShapeTable &shapes, int s1, int s2); 234 235 private: 236 // Replaces samples that are always fragmented with the corresponding 237 // fragment samples. 238 void ReplaceFragmentedSamples(); 239 240 // Runs a hierarchical agglomerative clustering to merge shapes in the given 241 // shape_table, while satisfying the given constraints: 242 // * End with at least min_shapes left in shape_table, 243 // * No shape shall have more than max_shape_unichars in it, 244 // * Don't merge shapes where the distance between them exceeds max_dist. 245 void ClusterShapes(int min_shapes, int max_shape_unichars, float max_dist, 246 ShapeTable *shape_table); 247 248 private: 249 NormalizationMode norm_mode_; 250 // Character set we are training for. 251 UNICHARSET unicharset_; 252 // Original feature space. Subspace mapping is contained in feature_map_. 253 IntFeatureSpace feature_space_; 254 TrainingSampleSet samples_; 255 TrainingSampleSet junk_samples_; 256 TrainingSampleSet verify_samples_; 257 // Master shape table defines what fonts stay together until the leaves. 258 ShapeTable master_shapes_; 259 // Flat shape table has each unichar/font id pair in a separate shape. 260 ShapeTable flat_shapes_; 261 // Font metrics gathered from multiple files. 262 FontInfoTable fontinfo_table_; 263 // Array of xheights indexed by font ids in fontinfo_table_; 264 std::vector<int32_t> xheights_; 265 266 // Non-serialized data initialized by other means or used temporarily 267 // during loading of training samples. 268 // Number of different class labels in unicharset_. 269 int charsetsize_; 270 // Flag to indicate that we are running shape analysis and need fragments 271 // fixing. 272 bool enable_shape_analysis_; 273 // Flag to indicate that sample replication is required. 274 bool enable_replication_; 275 // Array of classids of fragments that replace the correctly segmented chars. 276 int *fragments_; 277 // Classid of previous correctly segmented sample that was added. 278 int prev_unichar_id_; 279 // Debug output control. 280 int debug_level_; 281 // Feature map used to construct reduced feature spaces for compact 282 // classifiers. 283 IntFeatureMap feature_map_; 284 // Vector of Pix pointers used for classifiers that need the image. 285 // Indexed by page_num_ in the samples. 286 // These images are owned by the trainer and need to be pixDestroyed. 287 std::vector<Image > page_images_; 288 // Vector of filenames of loaded tr files. 289 std::vector<std::string> tr_filenames_; 290 }; 291 292 } // namespace tesseract. 293 294 #endif // TESSERACT_TRAINING_MASTERTRAINER_H_ 295