1 // Copyright 2010 Google Inc. All Rights Reserved. 2 // Author: rays@google.com (Ray Smith) 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 // 14 /////////////////////////////////////////////////////////////////////// 15 16 #ifndef TESSERACT_TRAINING_TRAININGSAMPLESET_H_ 17 #define TESSERACT_TRAINING_TRAININGSAMPLESET_H_ 18 19 #include "bitvector.h" 20 #include "indexmapbidi.h" 21 #include "matrix.h" 22 #include "shapetable.h" 23 #include "trainingsample.h" 24 25 namespace tesseract { 26 27 class UNICHARSET; 28 struct FontInfo; 29 class FontInfoTable; 30 class IntFeatureMap; 31 class IntFeatureSpace; 32 class TrainingSample; 33 struct UnicharAndFonts; 34 35 // Collection of TrainingSample used for training or testing a classifier. 36 // Provides several useful methods to operate on the collection as a whole, 37 // including outlier detection and deletion, providing access by font and 38 // class, finding the canonical sample, finding the "cloud" features (OR of 39 // all features in all samples), replication of samples, caching of distance 40 // metrics. 41 class TrainingSampleSet { 42 public: 43 explicit TrainingSampleSet(const FontInfoTable &fontinfo_table); 44 ~TrainingSampleSet(); 45 46 // Writes to the given file. Returns false in case of error. 47 bool Serialize(FILE *fp) const; 48 // Reads from the given file. Returns false in case of error. 49 // If swap is true, assumes a big/little-endian swap is needed. 50 bool DeSerialize(bool swap, FILE *fp); 51 52 // Accessors num_samples()53 int num_samples() const { 54 return samples_.size(); 55 } num_raw_samples()56 int num_raw_samples() const { 57 return num_raw_samples_; 58 } NumFonts()59 int NumFonts() const { 60 return font_id_map_.SparseSize(); 61 } unicharset()62 const UNICHARSET &unicharset() const { 63 return unicharset_; 64 } charsetsize()65 int charsetsize() const { 66 return unicharset_size_; 67 } fontinfo_table()68 const FontInfoTable &fontinfo_table() const { 69 return fontinfo_table_; 70 } 71 72 // Loads an initial unicharset, or sets one up if the file cannot be read. 73 void LoadUnicharset(const char *filename); 74 75 // Adds a character sample to this sample set. 76 // If the unichar is not already in the local unicharset, it is added. 77 // Returns the unichar_id of the added sample, from the local unicharset. 78 int AddSample(const char *unichar, TrainingSample *sample); 79 // Adds a character sample to this sample set with the given unichar_id, 80 // which must correspond to the local unicharset (in this). 81 void AddSample(int unichar_id, TrainingSample *sample); 82 83 // Returns the number of samples for the given font,class pair. 84 // If randomize is true, returns the number of samples accessible 85 // with randomizing on. (Increases the number of samples if small.) 86 // OrganizeByFontAndClass must have been already called. 87 int NumClassSamples(int font_id, int class_id, bool randomize) const; 88 89 // Gets a sample by its index. 90 const TrainingSample *GetSample(int index) const; 91 92 // Gets a sample by its font, class, index. 93 // OrganizeByFontAndClass must have been already called. 94 const TrainingSample *GetSample(int font_id, int class_id, int index) const; 95 96 // Get a sample by its font, class, index. Does not randomize. 97 // OrganizeByFontAndClass must have been already called. 98 TrainingSample *MutableSample(int font_id, int class_id, int index); 99 100 // Returns a string debug representation of the given sample: 101 // font, unichar_str, bounding box, page. 102 std::string SampleToString(const TrainingSample &sample) const; 103 104 // Gets the combined set of features used by all the samples of the given 105 // font/class combination. 106 const BitVector &GetCloudFeatures(int font_id, int class_id) const; 107 // Gets the indexed features of the canonical sample of the given 108 // font/class combination. 109 const std::vector<int> &GetCanonicalFeatures(int font_id, int class_id) const; 110 111 // Returns the distance between the given UniCharAndFonts pair. 112 // If matched_fonts, only matching fonts, are considered, unless that yields 113 // the empty set. 114 // OrganizeByFontAndClass must have been already called. 115 float UnicharDistance(const UnicharAndFonts &uf1, const UnicharAndFonts &uf2, bool matched_fonts, 116 const IntFeatureMap &feature_map); 117 118 // Returns the distance between the given pair of font/class pairs. 119 // Finds in cache or computes and caches. 120 // OrganizeByFontAndClass must have been already called. 121 float ClusterDistance(int font_id1, int class_id1, int font_id2, int class_id2, 122 const IntFeatureMap &feature_map); 123 124 // Computes the distance between the given pair of font/class pairs. 125 float ComputeClusterDistance(int font_id1, int class_id1, int font_id2, int class_id2, 126 const IntFeatureMap &feature_map) const; 127 128 // Returns the number of canonical features of font/class 2 for which 129 // neither the feature nor any of its near neighbors occurs in the cloud 130 // of font/class 1. Each such feature is a reliable separation between 131 // the classes, ASSUMING that the canonical sample is sufficiently 132 // representative that every sample has a feature near that particular 133 // feature. To check that this is so on the fly would be prohibitively 134 // expensive, but it might be possible to pre-qualify the canonical features 135 // to include only those for which this assumption is true. 136 // ComputeCanonicalFeatures and ComputeCloudFeatures must have been called 137 // first, or the results will be nonsense. 138 int ReliablySeparable(int font_id1, int class_id1, int font_id2, int class_id2, 139 const IntFeatureMap &feature_map, bool thorough) const; 140 141 // Returns the total index of the requested sample. 142 // OrganizeByFontAndClass must have been already called. 143 int GlobalSampleIndex(int font_id, int class_id, int index) const; 144 145 // Gets the canonical sample for the given font, class pair. 146 // ComputeCanonicalSamples must have been called first. 147 const TrainingSample *GetCanonicalSample(int font_id, int class_id) const; 148 // Gets the max distance for the given canonical sample. 149 // ComputeCanonicalSamples must have been called first. 150 float GetCanonicalDist(int font_id, int class_id) const; 151 152 // Returns a mutable pointer to the sample with the given index. mutable_sample(int index)153 TrainingSample *mutable_sample(int index) { 154 return samples_[index]; 155 } 156 // Gets ownership of the sample with the given index, removing it from this. extract_sample(int index)157 TrainingSample *extract_sample(int index) { 158 TrainingSample *sample = samples_[index]; 159 samples_[index] = nullptr; 160 return sample; 161 } 162 163 // Generates indexed features for all samples with the supplied feature_space. 164 void IndexFeatures(const IntFeatureSpace &feature_space); 165 166 // Marks the given sample for deletion. 167 // Deletion is actually completed by DeleteDeadSamples. 168 void KillSample(TrainingSample *sample); 169 170 // Deletes all samples with a negative sample index marked by KillSample. 171 // Must be called before OrganizeByFontAndClass, and OrganizeByFontAndClass 172 // must be called after as the samples have been renumbered. 173 void DeleteDeadSamples(); 174 175 // Construct an array to access the samples by font,class pair. 176 void OrganizeByFontAndClass(); 177 178 // Constructs the font_id_map_ which maps real font_ids (sparse) to a compact 179 // index for the font_class_array_. 180 void SetupFontIdMap(); 181 182 // Finds the sample for each font, class pair that has least maximum 183 // distance to all the other samples of the same font, class. 184 // OrganizeByFontAndClass must have been already called. 185 void ComputeCanonicalSamples(const IntFeatureMap &map, bool debug); 186 187 // Replicates the samples to a minimum frequency defined by 188 // 2 * kSampleRandomSize, or for larger counts duplicates all samples. 189 // After replication, the replicated samples are perturbed slightly, but 190 // in a predictable and repeatable way. 191 // Use after OrganizeByFontAndClass(). 192 void ReplicateAndRandomizeSamples(); 193 194 // Caches the indexed features of the canonical samples. 195 // ComputeCanonicalSamples must have been already called. 196 void ComputeCanonicalFeatures(); 197 // Computes the combined set of features used by all the samples of each 198 // font/class combination. Use after ReplicateAndRandomizeSamples. 199 void ComputeCloudFeatures(int feature_space_size); 200 201 // Adds all fonts of the given class to the shape. 202 void AddAllFontsForClass(int class_id, Shape *shape) const; 203 204 // Display the samples with the given indexed feature that also match 205 // the given shape. 206 void DisplaySamplesWithFeature(int f_index, const Shape &shape, 207 const IntFeatureSpace &feature_space, ScrollView::Color color, 208 ScrollView *window) const; 209 210 private: 211 // Struct to store a triplet of unichar, font, distance in the distance cache. 212 struct FontClassDistance { 213 int unichar_id; 214 int font_id; // Real font id. 215 float distance; 216 }; 217 // Simple struct to store information related to each font/class combination. 218 struct FontClassInfo { 219 FontClassInfo(); 220 221 // Writes to the given file. Returns false in case of error. 222 bool Serialize(FILE *fp) const; 223 // Reads from the given file. Returns false in case of error. 224 // If swap is true, assumes a big/little-endian swap is needed. 225 bool DeSerialize(bool swap, FILE *fp); 226 227 // Number of raw samples. 228 int32_t num_raw_samples; 229 // Index of the canonical sample. 230 int32_t canonical_sample; 231 // Max distance of the canonical sample from any other. 232 float canonical_dist; 233 // Sample indices for the samples, including replicated. 234 std::vector<int32_t> samples; 235 236 // Non-serialized cache data. 237 // Indexed features of the canonical sample. 238 std::vector<int> canonical_features; 239 // The mapped features of all the samples. 240 BitVector cloud_features; 241 242 // Caches for ClusterDistance. 243 // Caches for other fonts but matching this unichar. -1 indicates not set. 244 // Indexed by compact font index from font_id_map_. 245 std::vector<float> font_distance_cache; 246 // Caches for other unichars but matching this font. -1 indicates not set. 247 std::vector<float> unichar_distance_cache; 248 // Cache for the rest (non matching font and unichar.) 249 // A cache of distances computed by ReliablySeparable. 250 std::vector<FontClassDistance> distance_cache; 251 }; 252 253 std::vector<TrainingSample *> samples_; 254 // Number of samples before replication/randomization. 255 int num_raw_samples_; 256 // Character set we are training for. 257 UNICHARSET unicharset_; 258 // Character set size to which the 2-d arrays below refer. 259 int unicharset_size_; 260 // Map to allow the font_class_array_ below to be compact. 261 // The sparse space is the real font_id, used in samples_ . 262 // The compact space is an index to font_class_array_ 263 IndexMapBiDi font_id_map_; 264 // A 2-d array of FontClassInfo holding information related to each 265 // (font_id, class_id) pair. 266 GENERIC_2D_ARRAY<FontClassInfo> *font_class_array_; 267 268 // Reference to the fontinfo_table_ in MasterTrainer. Provides names 269 // for font_ids in the samples. Not serialized! 270 const FontInfoTable &fontinfo_table_; 271 }; 272 273 } // namespace tesseract. 274 275 #endif // TRAININGSAMPLESETSET_H_ 276