1 // Copyright 2010 Google Inc. All Rights Reserved.
2 // Author: rays@google.com (Ray Smith)
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
14 ///////////////////////////////////////////////////////////////////////
15 
16 #ifndef TESSERACT_TRAINING_TRAININGSAMPLESET_H_
17 #define TESSERACT_TRAINING_TRAININGSAMPLESET_H_
18 
19 #include "bitvector.h"
20 #include "indexmapbidi.h"
21 #include "matrix.h"
22 #include "shapetable.h"
23 #include "trainingsample.h"
24 
25 namespace tesseract {
26 
27 class UNICHARSET;
28 struct FontInfo;
29 class FontInfoTable;
30 class IntFeatureMap;
31 class IntFeatureSpace;
32 class TrainingSample;
33 struct UnicharAndFonts;
34 
35 // Collection of TrainingSample used for training or testing a classifier.
36 // Provides several useful methods to operate on the collection as a whole,
37 // including outlier detection and deletion, providing access by font and
38 // class, finding the canonical sample, finding the "cloud" features (OR of
39 // all features in all samples), replication of samples, caching of distance
40 // metrics.
41 class TrainingSampleSet {
42 public:
43   explicit TrainingSampleSet(const FontInfoTable &fontinfo_table);
44   ~TrainingSampleSet();
45 
46   // Writes to the given file. Returns false in case of error.
47   bool Serialize(FILE *fp) const;
48   // Reads from the given file. Returns false in case of error.
49   // If swap is true, assumes a big/little-endian swap is needed.
50   bool DeSerialize(bool swap, FILE *fp);
51 
52   // Accessors
num_samples()53   int num_samples() const {
54     return samples_.size();
55   }
num_raw_samples()56   int num_raw_samples() const {
57     return num_raw_samples_;
58   }
NumFonts()59   int NumFonts() const {
60     return font_id_map_.SparseSize();
61   }
unicharset()62   const UNICHARSET &unicharset() const {
63     return unicharset_;
64   }
charsetsize()65   int charsetsize() const {
66     return unicharset_size_;
67   }
fontinfo_table()68   const FontInfoTable &fontinfo_table() const {
69     return fontinfo_table_;
70   }
71 
72   // Loads an initial unicharset, or sets one up if the file cannot be read.
73   void LoadUnicharset(const char *filename);
74 
75   // Adds a character sample to this sample set.
76   // If the unichar is not already in the local unicharset, it is added.
77   // Returns the unichar_id of the added sample, from the local unicharset.
78   int AddSample(const char *unichar, TrainingSample *sample);
79   // Adds a character sample to this sample set with the given unichar_id,
80   // which must correspond to the local unicharset (in this).
81   void AddSample(int unichar_id, TrainingSample *sample);
82 
83   // Returns the number of samples for the given font,class pair.
84   // If randomize is true, returns the number of samples accessible
85   // with randomizing on. (Increases the number of samples if small.)
86   // OrganizeByFontAndClass must have been already called.
87   int NumClassSamples(int font_id, int class_id, bool randomize) const;
88 
89   // Gets a sample by its index.
90   const TrainingSample *GetSample(int index) const;
91 
92   // Gets a sample by its font, class, index.
93   // OrganizeByFontAndClass must have been already called.
94   const TrainingSample *GetSample(int font_id, int class_id, int index) const;
95 
96   // Get a sample by its font, class, index. Does not randomize.
97   // OrganizeByFontAndClass must have been already called.
98   TrainingSample *MutableSample(int font_id, int class_id, int index);
99 
100   // Returns a string debug representation of the given sample:
101   // font, unichar_str, bounding box, page.
102   std::string SampleToString(const TrainingSample &sample) const;
103 
104   // Gets the combined set of features used by all the samples of the given
105   // font/class combination.
106   const BitVector &GetCloudFeatures(int font_id, int class_id) const;
107   // Gets the indexed features of the canonical sample of the given
108   // font/class combination.
109   const std::vector<int> &GetCanonicalFeatures(int font_id, int class_id) const;
110 
111   // Returns the distance between the given UniCharAndFonts pair.
112   // If matched_fonts, only matching fonts, are considered, unless that yields
113   // the empty set.
114   // OrganizeByFontAndClass must have been already called.
115   float UnicharDistance(const UnicharAndFonts &uf1, const UnicharAndFonts &uf2, bool matched_fonts,
116                         const IntFeatureMap &feature_map);
117 
118   // Returns the distance between the given pair of font/class pairs.
119   // Finds in cache or computes and caches.
120   // OrganizeByFontAndClass must have been already called.
121   float ClusterDistance(int font_id1, int class_id1, int font_id2, int class_id2,
122                         const IntFeatureMap &feature_map);
123 
124   // Computes the distance between the given pair of font/class pairs.
125   float ComputeClusterDistance(int font_id1, int class_id1, int font_id2, int class_id2,
126                                const IntFeatureMap &feature_map) const;
127 
128   // Returns the number of canonical features of font/class 2 for which
129   // neither the feature nor any of its near neighbors occurs in the cloud
130   // of font/class 1. Each such feature is a reliable separation between
131   // the classes, ASSUMING that the canonical sample is sufficiently
132   // representative that every sample has a feature near that particular
133   // feature. To check that this is so on the fly would be prohibitively
134   // expensive, but it might be possible to pre-qualify the canonical features
135   // to include only those for which this assumption is true.
136   // ComputeCanonicalFeatures and ComputeCloudFeatures must have been called
137   // first, or the results will be nonsense.
138   int ReliablySeparable(int font_id1, int class_id1, int font_id2, int class_id2,
139                         const IntFeatureMap &feature_map, bool thorough) const;
140 
141   // Returns the total index of the requested sample.
142   // OrganizeByFontAndClass must have been already called.
143   int GlobalSampleIndex(int font_id, int class_id, int index) const;
144 
145   // Gets the canonical sample for the given font, class pair.
146   // ComputeCanonicalSamples must have been called first.
147   const TrainingSample *GetCanonicalSample(int font_id, int class_id) const;
148   // Gets the max distance for the given canonical sample.
149   // ComputeCanonicalSamples must have been called first.
150   float GetCanonicalDist(int font_id, int class_id) const;
151 
152   // Returns a mutable pointer to the sample with the given index.
mutable_sample(int index)153   TrainingSample *mutable_sample(int index) {
154     return samples_[index];
155   }
156   // Gets ownership of the sample with the given index, removing it from this.
extract_sample(int index)157   TrainingSample *extract_sample(int index) {
158     TrainingSample *sample = samples_[index];
159     samples_[index] = nullptr;
160     return sample;
161   }
162 
163   // Generates indexed features for all samples with the supplied feature_space.
164   void IndexFeatures(const IntFeatureSpace &feature_space);
165 
166   // Marks the given sample for deletion.
167   // Deletion is actually completed by DeleteDeadSamples.
168   void KillSample(TrainingSample *sample);
169 
170   // Deletes all samples with a negative sample index marked by KillSample.
171   // Must be called before OrganizeByFontAndClass, and OrganizeByFontAndClass
172   // must be called after as the samples have been renumbered.
173   void DeleteDeadSamples();
174 
175   // Construct an array to access the samples by font,class pair.
176   void OrganizeByFontAndClass();
177 
178   // Constructs the font_id_map_ which maps real font_ids (sparse) to a compact
179   // index for the font_class_array_.
180   void SetupFontIdMap();
181 
182   // Finds the sample for each font, class pair that has least maximum
183   // distance to all the other samples of the same font, class.
184   // OrganizeByFontAndClass must have been already called.
185   void ComputeCanonicalSamples(const IntFeatureMap &map, bool debug);
186 
187   // Replicates the samples to a minimum frequency defined by
188   // 2 * kSampleRandomSize, or for larger counts duplicates all samples.
189   // After replication, the replicated samples are perturbed slightly, but
190   // in a predictable and repeatable way.
191   // Use after OrganizeByFontAndClass().
192   void ReplicateAndRandomizeSamples();
193 
194   // Caches the indexed features of the canonical samples.
195   // ComputeCanonicalSamples must have been already called.
196   void ComputeCanonicalFeatures();
197   // Computes the combined set of features used by all the samples of each
198   // font/class combination. Use after ReplicateAndRandomizeSamples.
199   void ComputeCloudFeatures(int feature_space_size);
200 
201   // Adds all fonts of the given class to the shape.
202   void AddAllFontsForClass(int class_id, Shape *shape) const;
203 
204   // Display the samples with the given indexed feature that also match
205   // the given shape.
206   void DisplaySamplesWithFeature(int f_index, const Shape &shape,
207                                  const IntFeatureSpace &feature_space, ScrollView::Color color,
208                                  ScrollView *window) const;
209 
210 private:
211   // Struct to store a triplet of unichar, font, distance in the distance cache.
212   struct FontClassDistance {
213     int unichar_id;
214     int font_id; // Real font id.
215     float distance;
216   };
217   // Simple struct to store information related to each font/class combination.
218   struct FontClassInfo {
219     FontClassInfo();
220 
221     // Writes to the given file. Returns false in case of error.
222     bool Serialize(FILE *fp) const;
223     // Reads from the given file. Returns false in case of error.
224     // If swap is true, assumes a big/little-endian swap is needed.
225     bool DeSerialize(bool swap, FILE *fp);
226 
227     // Number of raw samples.
228     int32_t num_raw_samples;
229     // Index of the canonical sample.
230     int32_t canonical_sample;
231     // Max distance of the canonical sample from any other.
232     float canonical_dist;
233     // Sample indices for the samples, including replicated.
234     std::vector<int32_t> samples;
235 
236     // Non-serialized cache data.
237     // Indexed features of the canonical sample.
238     std::vector<int> canonical_features;
239     // The mapped features of all the samples.
240     BitVector cloud_features;
241 
242     // Caches for ClusterDistance.
243     // Caches for other fonts but matching this unichar. -1 indicates not set.
244     // Indexed by compact font index from font_id_map_.
245     std::vector<float> font_distance_cache;
246     // Caches for other unichars but matching this font. -1 indicates not set.
247     std::vector<float> unichar_distance_cache;
248     // Cache for the rest (non matching font and unichar.)
249     // A cache of distances computed by ReliablySeparable.
250     std::vector<FontClassDistance> distance_cache;
251   };
252 
253   std::vector<TrainingSample *> samples_;
254   // Number of samples before replication/randomization.
255   int num_raw_samples_;
256   // Character set we are training for.
257   UNICHARSET unicharset_;
258   // Character set size to which the 2-d arrays below refer.
259   int unicharset_size_;
260   // Map to allow the font_class_array_ below to be compact.
261   // The sparse space is the real font_id, used in samples_ .
262   // The compact space is an index to font_class_array_
263   IndexMapBiDi font_id_map_;
264   // A 2-d array of FontClassInfo holding information related to each
265   // (font_id, class_id) pair.
266   GENERIC_2D_ARRAY<FontClassInfo> *font_class_array_;
267 
268   // Reference to the fontinfo_table_ in MasterTrainer. Provides names
269   // for font_ids in the samples. Not serialized!
270   const FontInfoTable &fontinfo_table_;
271 };
272 
273 } // namespace tesseract.
274 
#endif // TESSERACT_TRAINING_TRAININGSAMPLESET_H_
276