1 // Copyright 2010 Google Inc. All Rights Reserved.
2 // Author: rays@google.com (Ray Smith)
3 ///////////////////////////////////////////////////////////////////////
4 // File:        mastertrainer.h
5 // Description: Trainer to build the MasterClassifier.
6 // Author:      Ray Smith
7 //
8 // (C) Copyright 2010, Google Inc.
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 // http://www.apache.org/licenses/LICENSE-2.0
13 // Unless required by applicable law or agreed to in writing, software
14 // distributed under the License is distributed on an "AS IS" BASIS,
15 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 // See the License for the specific language governing permissions and
17 // limitations under the License.
18 //
19 ///////////////////////////////////////////////////////////////////////
20 
21 #ifndef TESSERACT_TRAINING_MASTERTRAINER_H_
22 #define TESSERACT_TRAINING_MASTERTRAINER_H_
23 
24 #include "export.h"
25 
26 #include "classify.h"
27 #include "cluster.h"
28 #include "elst.h"
29 #include "errorcounter.h"
30 #include "featdefs.h"
31 #include "fontinfo.h"
32 #include "indexmapbidi.h"
33 #include "intfeaturemap.h"
34 #include "intfeaturespace.h"
35 #include "intfx.h"
36 #include "intmatcher.h"
37 #include "params.h"
38 #include "shapetable.h"
39 #include "trainingsample.h"
40 #include "trainingsampleset.h"
41 #include "unicharset.h"
42 
43 namespace tesseract {
44 
45 class ShapeClassifier;
46 
// Work item used while clustering shapes: a pair of shape indices together
// with the distance between those two shapes.
struct ShapeDist {
  // Default-constructed entries refer to shape 0 twice with zero distance.
  ShapeDist() = default;
  ShapeDist(int s1, int s2, float dist) : shape1{s1}, shape2{s2}, distance{dist} {}

  // Strict weak ordering on distance alone (the shape indices are ignored),
  // so sorting a container of ShapeDist yields ascending distance order.
  bool operator<(const ShapeDist &other) const {
    return other.distance > distance;
  }

  int shape1 = 0;
  int shape2 = 0;
  float distance = 0.0f;
};
61 
62 // Class to encapsulate training processes that use the TrainingSampleSet.
63 // Initially supports shape clustering and mftrainining.
64 // Other important features of the MasterTrainer are conditioning the data
65 // by outlier elimination, replication with perturbation, and serialization.
66 class TESS_COMMON_TRAINING_API MasterTrainer {
67 public:
68   MasterTrainer(NormalizationMode norm_mode, bool shape_analysis, bool replicate_samples,
69                 int debug_level);
70   ~MasterTrainer();
71 
72   // Writes to the given file. Returns false in case of error.
73   bool Serialize(FILE *fp) const;
74 
75   // Loads an initial unicharset, or sets one up if the file cannot be read.
76   void LoadUnicharset(const char *filename);
77 
78   // Sets the feature space definition.
SetFeatureSpace(const IntFeatureSpace & fs)79   void SetFeatureSpace(const IntFeatureSpace &fs) {
80     feature_space_ = fs;
81     feature_map_.Init(fs);
82   }
83 
84   // Reads the samples and their features from the given file,
85   // adding them to the trainer with the font_id from the content of the file.
86   // If verification, then these are verification samples, not training.
87   void ReadTrainingSamples(const char *page_name, const FEATURE_DEFS_STRUCT &feature_defs,
88                            bool verification);
89 
90   // Adds the given single sample to the trainer, setting the classid
91   // appropriately from the given unichar_str.
92   void AddSample(bool verification, const char *unichar_str, TrainingSample *sample);
93 
94   // Loads all pages from the given tif filename and append to page_images_.
95   // Must be called after ReadTrainingSamples, as the current number of images
96   // is used as an offset for page numbers in the samples.
97   void LoadPageImages(const char *filename);
98 
99   // Cleans up the samples after initial load from the tr files, and prior to
100   // saving the MasterTrainer:
101   // Remaps fragmented chars if running shape analysis.
102   // Sets up the samples appropriately for class/fontwise access.
103   // Deletes outlier samples.
104   void PostLoadCleanup();
105 
106   // Gets the samples ready for training. Use after both
107   // ReadTrainingSamples+PostLoadCleanup or DeSerialize.
108   // Re-indexes the features and computes canonical and cloud features.
109   void PreTrainingSetup();
110 
111   // Sets up the master_shapes_ table, which tells which fonts should stay
112   // together until they get to a leaf node classifier.
113   void SetupMasterShapes();
114 
115   // Adds the junk_samples_ to the main samples_ set. Junk samples are initially
116   // fragments and n-grams (all incorrectly segmented characters).
117   // Various training functions may result in incorrectly segmented characters
118   // being added to the unicharset of the main samples, perhaps because they
119   // form a "radical" decomposition of some (Indic) grapheme, or because they
120   // just look the same as a real character (like rn/m)
121   // This function moves all the junk samples, to the main samples_ set, but
122   // desirable junk, being any sample for which the unichar already exists in
123   // the samples_ unicharset gets the unichar-ids re-indexed to match, but
124   // anything else gets re-marked as unichar_id 0 (space character) to identify
125   // it as junk to the error counter.
126   void IncludeJunk();
127 
128   // Replicates the samples and perturbs them if the enable_replication_ flag
129   // is set. MUST be used after the last call to OrganizeByFontAndClass on
130   // the training samples, ie after IncludeJunk if it is going to be used, as
131   // OrganizeByFontAndClass will eat the replicated samples into the regular
132   // samples.
133   void ReplicateAndRandomizeSamplesIfRequired();
134 
135   // Loads the basic font properties file into fontinfo_table_.
136   // Returns false on failure.
137   bool LoadFontInfo(const char *filename);
138 
139   // Loads the xheight font properties file into xheights_.
140   // Returns false on failure.
141   bool LoadXHeights(const char *filename);
142 
143   // Reads spacing stats from filename and adds them to fontinfo_table.
144   // Returns false on failure.
145   bool AddSpacingInfo(const char *filename);
146 
147   // Returns the font id corresponding to the given font name.
148   // Returns -1 if the font cannot be found.
149   int GetFontInfoId(const char *font_name);
150   // Returns the font_id of the closest matching font name to the given
151   // filename. It is assumed that a substring of the filename will match
152   // one of the fonts. If more than one is matched, the longest is returned.
153   int GetBestMatchingFontInfoId(const char *filename);
154 
155   // Returns the filename of the tr file corresponding to the command-line
156   // argument with the given index.
GetTRFileName(int index)157   const std::string &GetTRFileName(int index) const {
158     return tr_filenames_[index];
159   }
160 
161   // Sets up a flat shapetable with one shape per class/font combination.
162   void SetupFlatShapeTable(ShapeTable *shape_table);
163 
164   // Sets up a Clusterer for mftraining on a single shape_id.
165   // Call FreeClusterer on the return value after use.
166   CLUSTERER *SetupForClustering(const ShapeTable &shape_table,
167                                 const FEATURE_DEFS_STRUCT &feature_defs, int shape_id,
168                                 int *num_samples);
169 
170   // Writes the given float_classes (produced by SetupForFloat2Int) as inttemp
171   // to the given inttemp_file, and the corresponding pffmtable.
172   // The unicharset is the original encoding of graphemes, and shape_set should
173   // match the size of the shape_table, and may possibly be totally fake.
174   void WriteInttempAndPFFMTable(const UNICHARSET &unicharset, const UNICHARSET &shape_set,
175                                 const ShapeTable &shape_table, CLASS_STRUCT *float_classes,
176                                 const char *inttemp_file, const char *pffmtable_file);
177 
unicharset()178   const UNICHARSET &unicharset() const {
179     return samples_.unicharset();
180   }
GetSamples()181   TrainingSampleSet *GetSamples() {
182     return &samples_;
183   }
master_shapes()184   const ShapeTable &master_shapes() const {
185     return master_shapes_;
186   }
187 
188   // Generates debug output relating to the canonical distance between the
189   // two given UTF8 grapheme strings.
190   void DebugCanonical(const char *unichar_str1, const char *unichar_str2);
191 #ifndef GRAPHICS_DISABLED
192   // Debugging for cloud/canonical features.
193   // Displays a Features window containing:
194   // If unichar_str2 is in the unicharset, and canonical_font is non-negative,
195   // displays the canonical features of the char/font combination in red.
196   // If unichar_str1 is in the unicharset, and cloud_font is non-negative,
197   // displays the cloud feature of the char/font combination in green.
198   // The canonical features are drawn first to show which ones have no
199   // matches in the cloud features.
200   // Until the features window is destroyed, each click in the features window
201   // will display the samples that have that feature in a separate window.
202   void DisplaySamples(const char *unichar_str1, int cloud_font, const char *unichar_str2,
203                       int canonical_font);
204 #endif // !GRAPHICS_DISABLED
205 
206   void TestClassifierVOld(bool replicate_samples, ShapeClassifier *test_classifier,
207                           ShapeClassifier *old_classifier);
208 
209   // Tests the given test_classifier on the internal samples.
210   // See TestClassifier for details.
211   void TestClassifierOnSamples(CountTypes error_mode, int report_level, bool replicate_samples,
212                                ShapeClassifier *test_classifier, std::string *report_string);
213   // Tests the given test_classifier on the given samples
214   // error_mode indicates what counts as an error.
215   // report_levels:
216   // 0 = no output.
217   // 1 = bottom-line error rate.
218   // 2 = bottom-line error rate + time.
219   // 3 = font-level error rate + time.
220   // 4 = list of all errors + short classifier debug output on 16 errors.
221   // 5 = list of all errors + short classifier debug output on 25 errors.
222   // If replicate_samples is true, then the test is run on an extended test
223   // sample including replicated and systematically perturbed samples.
224   // If report_string is non-nullptr, a summary of the results for each font
225   // is appended to the report_string.
226   double TestClassifier(CountTypes error_mode, int report_level, bool replicate_samples,
227                         TrainingSampleSet *samples, ShapeClassifier *test_classifier,
228                         std::string *report_string);
229 
230   // Returns the average (in some sense) distance between the two given
231   // shapes, which may contain multiple fonts and/or unichars.
232   // This function is public to facilitate testing.
233   float ShapeDistance(const ShapeTable &shapes, int s1, int s2);
234 
235 private:
236   // Replaces samples that are always fragmented with the corresponding
237   // fragment samples.
238   void ReplaceFragmentedSamples();
239 
240   // Runs a hierarchical agglomerative clustering to merge shapes in the given
241   // shape_table, while satisfying the given constraints:
242   // * End with at least min_shapes left in shape_table,
243   // * No shape shall have more than max_shape_unichars in it,
244   // * Don't merge shapes where the distance between them exceeds max_dist.
245   void ClusterShapes(int min_shapes, int max_shape_unichars, float max_dist,
246                      ShapeTable *shape_table);
247 
248 private:
249   NormalizationMode norm_mode_;
250   // Character set we are training for.
251   UNICHARSET unicharset_;
252   // Original feature space. Subspace mapping is contained in feature_map_.
253   IntFeatureSpace feature_space_;
254   TrainingSampleSet samples_;
255   TrainingSampleSet junk_samples_;
256   TrainingSampleSet verify_samples_;
257   // Master shape table defines what fonts stay together until the leaves.
258   ShapeTable master_shapes_;
259   // Flat shape table has each unichar/font id pair in a separate shape.
260   ShapeTable flat_shapes_;
261   // Font metrics gathered from multiple files.
262   FontInfoTable fontinfo_table_;
263   // Array of xheights indexed by font ids in fontinfo_table_;
264   std::vector<int32_t> xheights_;
265 
266   // Non-serialized data initialized by other means or used temporarily
267   // during loading of training samples.
268   // Number of different class labels in unicharset_.
269   int charsetsize_;
270   // Flag to indicate that we are running shape analysis and need fragments
271   // fixing.
272   bool enable_shape_analysis_;
273   // Flag to indicate that sample replication is required.
274   bool enable_replication_;
275   // Array of classids of fragments that replace the correctly segmented chars.
276   int *fragments_;
277   // Classid of previous correctly segmented sample that was added.
278   int prev_unichar_id_;
279   // Debug output control.
280   int debug_level_;
281   // Feature map used to construct reduced feature spaces for compact
282   // classifiers.
283   IntFeatureMap feature_map_;
284   // Vector of Pix pointers used for classifiers that need the image.
285   // Indexed by page_num_ in the samples.
286   // These images are owned by the trainer and need to be pixDestroyed.
287   std::vector<Image > page_images_;
288   // Vector of filenames of loaded tr files.
289   std::vector<std::string> tr_filenames_;
290 };
291 
292 } // namespace tesseract.
293 
294 #endif // TESSERACT_TRAINING_MASTERTRAINER_H_
295