1 // Copyright 2010 Google Inc. All Rights Reserved.
2 // Author: rays@google.com (Ray Smith)
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
14 ///////////////////////////////////////////////////////////////////////
15 
16 #ifdef HAVE_CONFIG_H
17 #  include "config_auto.h"
18 #endif
19 
20 #include <algorithm>
21 
22 #include <allheaders.h>
23 #include "boxread.h"
24 #include "fontinfo.h"
25 //#include "helpers.h"
26 #include "indexmapbidi.h"
27 #include "intfeaturedist.h"
28 #include "intfeaturemap.h"
29 #include "intfeaturespace.h"
30 #include "shapetable.h"
31 #include "trainingsample.h"
32 #include "trainingsampleset.h"
33 #include "unicity_table.h"
34 
35 namespace tesseract {
36 
// Debug filter used by ComputeCanonicalSamples: when >= 0, every class other
// than this id is skipped there. -1 disables the filter.
// (The trailing 37 is presumably a value used in a past debugging session.)
const int kTestChar = -1; // 37;
// Max number of distances to compute the squared way, i.e. UnicharDistance
// computes every font pair only while num_fonts1 * num_fonts2 <= this limit.
const int kSquareLimit = 25;
// Prime numbers for subsampling distances. UnicharDistance steps through the
// second font list with kPrime1, or with kPrime2 when that list has exactly
// kPrime1 entries.
const int kPrime1 = 17;
const int kPrime2 = 13;
43 
FontClassInfo()44 TrainingSampleSet::FontClassInfo::FontClassInfo()
45     : num_raw_samples(0), canonical_sample(-1), canonical_dist(0.0f) {}
46 
47 // Writes to the given file. Returns false in case of error.
Serialize(FILE * fp) const48 bool TrainingSampleSet::FontClassInfo::Serialize(FILE *fp) const {
49   if (fwrite(&num_raw_samples, sizeof(num_raw_samples), 1, fp) != 1) {
50     return false;
51   }
52   if (fwrite(&canonical_sample, sizeof(canonical_sample), 1, fp) != 1) {
53     return false;
54   }
55   if (fwrite(&canonical_dist, sizeof(canonical_dist), 1, fp) != 1) {
56     return false;
57   }
58   if (!::tesseract::Serialize(fp, samples)) {
59     return false;
60   }
61   return true;
62 }
63 // Reads from the given file. Returns false in case of error.
64 // If swap is true, assumes a big/little-endian swap is needed.
DeSerialize(bool swap,FILE * fp)65 bool TrainingSampleSet::FontClassInfo::DeSerialize(bool swap, FILE *fp) {
66   if (fread(&num_raw_samples, sizeof(num_raw_samples), 1, fp) != 1) {
67     return false;
68   }
69   if (fread(&canonical_sample, sizeof(canonical_sample), 1, fp) != 1) {
70     return false;
71   }
72   if (fread(&canonical_dist, sizeof(canonical_dist), 1, fp) != 1) {
73     return false;
74   }
75   if (!::tesseract::DeSerialize(swap, fp, samples)) {
76     return false;
77   }
78   if (swap) {
79     ReverseN(&num_raw_samples, sizeof(num_raw_samples));
80     ReverseN(&canonical_sample, sizeof(canonical_sample));
81     ReverseN(&canonical_dist, sizeof(canonical_dist));
82   }
83   return true;
84 }
85 
TrainingSampleSet(const FontInfoTable & font_table)86 TrainingSampleSet::TrainingSampleSet(const FontInfoTable &font_table)
87     : num_raw_samples_(0)
88     , unicharset_size_(0)
89     , font_class_array_(nullptr)
90     , fontinfo_table_(font_table) {}
91 
~TrainingSampleSet()92 TrainingSampleSet::~TrainingSampleSet() {
93   for (auto sample : samples_) {
94     delete sample;
95   }
96   delete font_class_array_;
97 }
98 
99 // Writes to the given file. Returns false in case of error.
Serialize(FILE * fp) const100 bool TrainingSampleSet::Serialize(FILE *fp) const {
101   if (!tesseract::Serialize(fp, samples_)) {
102     return false;
103   }
104   if (!unicharset_.save_to_file(fp)) {
105     return false;
106   }
107   if (!font_id_map_.Serialize(fp)) {
108     return false;
109   }
110   int8_t not_null = font_class_array_ != nullptr;
111   if (fwrite(&not_null, sizeof(not_null), 1, fp) != 1) {
112     return false;
113   }
114   if (not_null) {
115     if (!font_class_array_->SerializeClasses(fp)) {
116       return false;
117     }
118   }
119   return true;
120 }
121 
122 // Reads from the given file. Returns false in case of error.
123 // If swap is true, assumes a big/little-endian swap is needed.
DeSerialize(bool swap,FILE * fp)124 bool TrainingSampleSet::DeSerialize(bool swap, FILE *fp) {
125   if (!tesseract::DeSerialize(swap, fp, samples_)) {
126     return false;
127   }
128   num_raw_samples_ = samples_.size();
129   if (!unicharset_.load_from_file(fp)) {
130     return false;
131   }
132   if (!font_id_map_.DeSerialize(swap, fp)) {
133     return false;
134   }
135   delete font_class_array_;
136   font_class_array_ = nullptr;
137   int8_t not_null;
138   if (fread(&not_null, sizeof(not_null), 1, fp) != 1) {
139     return false;
140   }
141   if (not_null) {
142     FontClassInfo empty;
143     font_class_array_ = new GENERIC_2D_ARRAY<FontClassInfo>(1, 1, empty);
144     if (!font_class_array_->DeSerializeClasses(swap, fp)) {
145       return false;
146     }
147   }
148   unicharset_size_ = unicharset_.size();
149   return true;
150 }
151 
152 // Load an initial unicharset, or set one up if the file cannot be read.
LoadUnicharset(const char * filename)153 void TrainingSampleSet::LoadUnicharset(const char *filename) {
154   if (!unicharset_.load_from_file(filename)) {
155     tprintf(
156         "Failed to load unicharset from file %s\n"
157         "Building unicharset from scratch...\n",
158         filename);
159     unicharset_.clear();
160     // Add special characters as they were removed by the clear.
161     UNICHARSET empty;
162     unicharset_.AppendOtherUnicharset(empty);
163   }
164   unicharset_size_ = unicharset_.size();
165 }
166 
167 // Adds a character sample to this sample set.
168 // If the unichar is not already in the local unicharset, it is added.
169 // Returns the unichar_id of the added sample, from the local unicharset.
AddSample(const char * unichar,TrainingSample * sample)170 int TrainingSampleSet::AddSample(const char *unichar, TrainingSample *sample) {
171   if (!unicharset_.contains_unichar(unichar)) {
172     unicharset_.unichar_insert(unichar);
173     if (unicharset_.size() > MAX_NUM_CLASSES) {
174       tprintf(
175           "Error: Size of unicharset in TrainingSampleSet::AddSample is "
176           "greater than MAX_NUM_CLASSES\n");
177       return -1;
178     }
179   }
180   UNICHAR_ID char_id = unicharset_.unichar_to_id(unichar);
181   AddSample(char_id, sample);
182   return char_id;
183 }
184 
185 // Adds a character sample to this sample set with the given unichar_id,
186 // which must correspond to the local unicharset (in this).
AddSample(int unichar_id,TrainingSample * sample)187 void TrainingSampleSet::AddSample(int unichar_id, TrainingSample *sample) {
188   sample->set_class_id(unichar_id);
189   samples_.push_back(sample);
190   num_raw_samples_ = samples_.size();
191   unicharset_size_ = unicharset_.size();
192 }
193 
194 // Returns the number of samples for the given font,class pair.
195 // If randomize is true, returns the number of samples accessible
196 // with randomizing on. (Increases the number of samples if small.)
197 // OrganizeByFontAndClass must have been already called.
NumClassSamples(int font_id,int class_id,bool randomize) const198 int TrainingSampleSet::NumClassSamples(int font_id, int class_id, bool randomize) const {
199   ASSERT_HOST(font_class_array_ != nullptr);
200   if (font_id < 0 || class_id < 0 || font_id >= font_id_map_.SparseSize() ||
201       class_id >= unicharset_size_) {
202     // There are no samples because the font or class doesn't exist.
203     return 0;
204   }
205   int font_index = font_id_map_.SparseToCompact(font_id);
206   if (font_index < 0) {
207     return 0; // The font has no samples.
208   }
209   if (randomize) {
210     return (*font_class_array_)(font_index, class_id).samples.size();
211   } else {
212     return (*font_class_array_)(font_index, class_id).num_raw_samples;
213   }
214 }
215 
216 // Gets a sample by its index.
GetSample(int index) const217 const TrainingSample *TrainingSampleSet::GetSample(int index) const {
218   return samples_[index];
219 }
220 
221 // Gets a sample by its font, class, index.
222 // OrganizeByFontAndClass must have been already called.
GetSample(int font_id,int class_id,int index) const223 const TrainingSample *TrainingSampleSet::GetSample(int font_id, int class_id, int index) const {
224   ASSERT_HOST(font_class_array_ != nullptr);
225   int font_index = font_id_map_.SparseToCompact(font_id);
226   if (font_index < 0) {
227     return nullptr;
228   }
229   int sample_index = (*font_class_array_)(font_index, class_id).samples[index];
230   return samples_[sample_index];
231 }
232 
233 // Get a sample by its font, class, index. Does not randomize.
234 // OrganizeByFontAndClass must have been already called.
MutableSample(int font_id,int class_id,int index)235 TrainingSample *TrainingSampleSet::MutableSample(int font_id, int class_id, int index) {
236   ASSERT_HOST(font_class_array_ != nullptr);
237   int font_index = font_id_map_.SparseToCompact(font_id);
238   if (font_index < 0) {
239     return nullptr;
240   }
241   int sample_index = (*font_class_array_)(font_index, class_id).samples[index];
242   return samples_[sample_index];
243 }
244 
245 // Returns a string debug representation of the given sample:
246 // font, unichar_str, bounding box, page.
SampleToString(const TrainingSample & sample) const247 std::string TrainingSampleSet::SampleToString(const TrainingSample &sample) const {
248   std::string boxfile_str;
249   MakeBoxFileStr(unicharset_.id_to_unichar(sample.class_id()), sample.bounding_box(),
250                  sample.page_num(), boxfile_str);
251   return std::string(fontinfo_table_.at(sample.font_id()).name) + " " + boxfile_str;
252 }
253 
254 // Gets the combined set of features used by all the samples of the given
255 // font/class combination.
GetCloudFeatures(int font_id,int class_id) const256 const BitVector &TrainingSampleSet::GetCloudFeatures(int font_id, int class_id) const {
257   int font_index = font_id_map_.SparseToCompact(font_id);
258   ASSERT_HOST(font_index >= 0);
259   return (*font_class_array_)(font_index, class_id).cloud_features;
260 }
261 // Gets the indexed features of the canonical sample of the given
262 // font/class combination.
GetCanonicalFeatures(int font_id,int class_id) const263 const std::vector<int> &TrainingSampleSet::GetCanonicalFeatures(int font_id, int class_id) const {
264   int font_index = font_id_map_.SparseToCompact(font_id);
265   ASSERT_HOST(font_index >= 0);
266   return (*font_class_array_)(font_index, class_id).canonical_features;
267 }
268 
269 // Returns the distance between the given UniCharAndFonts pair.
270 // If matched_fonts, only matching fonts, are considered, unless that yields
271 // the empty set.
272 // OrganizeByFontAndClass must have been already called.
UnicharDistance(const UnicharAndFonts & uf1,const UnicharAndFonts & uf2,bool matched_fonts,const IntFeatureMap & feature_map)273 float TrainingSampleSet::UnicharDistance(const UnicharAndFonts &uf1, const UnicharAndFonts &uf2,
274                                          bool matched_fonts, const IntFeatureMap &feature_map) {
275   int num_fonts1 = uf1.font_ids.size();
276   int c1 = uf1.unichar_id;
277   int num_fonts2 = uf2.font_ids.size();
278   int c2 = uf2.unichar_id;
279   double dist_sum = 0.0;
280   int dist_count = 0;
281   const bool debug = false;
282   if (matched_fonts) {
283     // Compute distances only where fonts match.
284     for (int i = 0; i < num_fonts1; ++i) {
285       int f1 = uf1.font_ids[i];
286       for (int j = 0; j < num_fonts2; ++j) {
287         int f2 = uf2.font_ids[j];
288         if (f1 == f2) {
289           dist_sum += ClusterDistance(f1, c1, f2, c2, feature_map);
290           ++dist_count;
291         }
292       }
293     }
294   } else if (num_fonts1 * num_fonts2 <= kSquareLimit) {
295     // Small enough sets to compute all the distances.
296     for (int i = 0; i < num_fonts1; ++i) {
297       int f1 = uf1.font_ids[i];
298       for (int j = 0; j < num_fonts2; ++j) {
299         int f2 = uf2.font_ids[j];
300         dist_sum += ClusterDistance(f1, c1, f2, c2, feature_map);
301         if (debug) {
302           tprintf("Cluster dist %d %d %d %d = %g\n", f1, c1, f2, c2,
303                   ClusterDistance(f1, c1, f2, c2, feature_map));
304         }
305         ++dist_count;
306       }
307     }
308   } else {
309     // Subsample distances, using the largest set once, and stepping through
310     // the smaller set so as to ensure that all the pairs are different.
311     int increment = kPrime1 != num_fonts2 ? kPrime1 : kPrime2;
312     int index = 0;
313     int num_samples = std::max(num_fonts1, num_fonts2);
314     for (int i = 0; i < num_samples; ++i, index += increment) {
315       int f1 = uf1.font_ids[i % num_fonts1];
316       int f2 = uf2.font_ids[index % num_fonts2];
317       if (debug) {
318         tprintf("Cluster dist %d %d %d %d = %g\n", f1, c1, f2, c2,
319                 ClusterDistance(f1, c1, f2, c2, feature_map));
320       }
321       dist_sum += ClusterDistance(f1, c1, f2, c2, feature_map);
322       ++dist_count;
323     }
324   }
325   if (dist_count == 0) {
326     if (matched_fonts) {
327       return UnicharDistance(uf1, uf2, false, feature_map);
328     }
329     return 0.0f;
330   }
331   return dist_sum / dist_count;
332 }
333 
// Returns the distance between the given pair of font/class pairs.
// Finds in cache or computes and caches.
// Three caches, all stored in the FontClassInfo of (font_id1, class_id1),
// are used depending on how the two pairs relate:
//  - same font:  unichar_distance_cache, indexed by class_id2;
//  - same class: font_distance_cache, indexed by the compact index of font 2;
//  - otherwise:  distance_cache, a list searched linearly.
// Every newly computed distance is also stored in the mirror-image entry of
// (font_id2, class_id2), so the reverse lookup is served from the cache.
// Returns 0 if either font has no compact index.
// class_id1 and class_id2 are not range checked here.
// OrganizeByFontAndClass must have been already called.
float TrainingSampleSet::ClusterDistance(int font_id1, int class_id1, int font_id2, int class_id2,
                                         const IntFeatureMap &feature_map) {
  ASSERT_HOST(font_class_array_ != nullptr);
  int font_index1 = font_id_map_.SparseToCompact(font_id1);
  int font_index2 = font_id_map_.SparseToCompact(font_id2);
  if (font_index1 < 0 || font_index2 < 0) {
    return 0.0f;
  }
  FontClassInfo &fc_info = (*font_class_array_)(font_index1, class_id1);
  if (font_id1 == font_id2) {
    // Special case cache for speed.
    // The cache is allocated on first use; a negative entry means that the
    // distance has not been computed yet.
    if (fc_info.unichar_distance_cache.empty()) {
      fc_info.unichar_distance_cache.resize(unicharset_size_, -1.0f);
    }
    if (fc_info.unichar_distance_cache[class_id2] < 0) {
      // Distance has to be calculated.
      float result = ComputeClusterDistance(font_id1, class_id1, font_id2, class_id2, feature_map);
      fc_info.unichar_distance_cache[class_id2] = result;
      // Copy to the symmetric cache entry.
      FontClassInfo &fc_info2 = (*font_class_array_)(font_index2, class_id2);
      if (fc_info2.unichar_distance_cache.empty()) {
        fc_info2.unichar_distance_cache.resize(unicharset_size_, -1.0f);
      }
      fc_info2.unichar_distance_cache[class_id1] = result;
    }
    return fc_info.unichar_distance_cache[class_id2];
  } else if (class_id1 == class_id2) {
    // Another special-case cache for equal class-id.
    // Indexed by compact font index, so it is sized by CompactSize().
    if (fc_info.font_distance_cache.empty()) {
      fc_info.font_distance_cache.resize(font_id_map_.CompactSize(), -1.0f);
    }
    if (fc_info.font_distance_cache[font_index2] < 0) {
      // Distance has to be calculated.
      float result = ComputeClusterDistance(font_id1, class_id1, font_id2, class_id2, feature_map);
      fc_info.font_distance_cache[font_index2] = result;
      // Copy to the symmetric cache entry.
      FontClassInfo &fc_info2 = (*font_class_array_)(font_index2, class_id2);
      if (fc_info2.font_distance_cache.empty()) {
        fc_info2.font_distance_cache.resize(font_id_map_.CompactSize(), -1.0f);
      }
      fc_info2.font_distance_cache[font_index1] = result;
    }
    return fc_info.font_distance_cache[font_index2];
  }
  // Both font and class are different. Linear search for class_id2/font_id2
  // in what is a hopefully short list of distances.
  size_t cache_index = 0;
  while (cache_index < fc_info.distance_cache.size() &&
         (fc_info.distance_cache[cache_index].unichar_id != class_id2 ||
          fc_info.distance_cache[cache_index].font_id != font_id2)) {
    ++cache_index;
  }
  if (cache_index == fc_info.distance_cache.size()) {
    // Distance has to be calculated.
    float result = ComputeClusterDistance(font_id1, class_id1, font_id2, class_id2, feature_map);
    FontClassDistance fc_dist = {class_id2, font_id2, result};
    // Appended at position cache_index, which is what the final return reads.
    fc_info.distance_cache.push_back(fc_dist);
    // Copy to the symmetric cache entry. We know it isn't there already, as
    // we always copy to the symmetric entry.
    FontClassInfo &fc_info2 = (*font_class_array_)(font_index2, class_id2);
    fc_dist.unichar_id = class_id1;
    fc_dist.font_id = font_id1;
    fc_info2.distance_cache.push_back(fc_dist);
  }
  return fc_info.distance_cache[cache_index].distance;
}
403 
404 // Computes the distance between the given pair of font/class pairs.
ComputeClusterDistance(int font_id1,int class_id1,int font_id2,int class_id2,const IntFeatureMap & feature_map) const405 float TrainingSampleSet::ComputeClusterDistance(int font_id1, int class_id1, int font_id2,
406                                                 int class_id2,
407                                                 const IntFeatureMap &feature_map) const {
408   int dist = ReliablySeparable(font_id1, class_id1, font_id2, class_id2, feature_map, false);
409   dist += ReliablySeparable(font_id2, class_id2, font_id1, class_id1, feature_map, false);
410   int denominator = GetCanonicalFeatures(font_id1, class_id1).size();
411   denominator += GetCanonicalFeatures(font_id2, class_id2).size();
412   return static_cast<float>(dist) / denominator;
413 }
414 
415 // Helper to add a feature and its near neighbors to the good_features.
416 // levels indicates how many times to compute the offset features of what is
417 // already there. This is done by iteration rather than recursion.
AddNearFeatures(const IntFeatureMap & feature_map,int f,int levels,std::vector<int> * good_features)418 static void AddNearFeatures(const IntFeatureMap &feature_map, int f, int levels,
419                             std::vector<int> *good_features) {
420   int prev_num_features = 0;
421   good_features->push_back(f);
422   int num_features = 1;
423   for (int level = 0; level < levels; ++level) {
424     for (int i = prev_num_features; i < num_features; ++i) {
425       int feature = (*good_features)[i];
426       for (int dir = -kNumOffsetMaps; dir <= kNumOffsetMaps; ++dir) {
427         if (dir == 0) {
428           continue;
429         }
430         int f1 = feature_map.OffsetFeature(feature, dir);
431         if (f1 >= 0) {
432           good_features->push_back(f1);
433         }
434       }
435     }
436     prev_num_features = num_features;
437     num_features = good_features->size();
438   }
439 }
440 
441 // Returns the number of canonical features of font/class 2 for which
442 // neither the feature nor any of its near neighbors occurs in the cloud
443 // of font/class 1. Each such feature is a reliable separation between
444 // the classes, ASSUMING that the canonical sample is sufficiently
445 // representative that every sample has a feature near that particular
446 // feature. To check that this is so on the fly would be prohibitively
447 // expensive, but it might be possible to pre-qualify the canonical features
448 // to include only those for which this assumption is true.
449 // ComputeCanonicalFeatures and ComputeCloudFeatures must have been called
450 // first, or the results will be nonsense.
ReliablySeparable(int font_id1,int class_id1,int font_id2,int class_id2,const IntFeatureMap & feature_map,bool thorough) const451 int TrainingSampleSet::ReliablySeparable(int font_id1, int class_id1, int font_id2, int class_id2,
452                                          const IntFeatureMap &feature_map, bool thorough) const {
453   int result = 0;
454   const TrainingSample *sample2 = GetCanonicalSample(font_id2, class_id2);
455   if (sample2 == nullptr) {
456     return 0; // There are no canonical features.
457   }
458   const std::vector<int> &canonical2 = GetCanonicalFeatures(font_id2, class_id2);
459   const BitVector &cloud1 = GetCloudFeatures(font_id1, class_id1);
460   if (cloud1.empty()) {
461     return canonical2.size(); // There are no cloud features.
462   }
463 
464   // Find a canonical2 feature that is not in cloud1.
465   for (int feature : canonical2) {
466     if (cloud1[feature]) {
467       continue;
468     }
469     // Gather the near neighbours of f.
470     std::vector<int> good_features;
471     AddNearFeatures(feature_map, feature, 1, &good_features);
472     // Check that none of the good_features are in the cloud.
473     bool found = false;
474     for (auto good_f : good_features) {
475       if (cloud1[good_f]) {
476         found = true;
477         break;
478       }
479     }
480     if (found) {
481       continue; // Found one in the cloud.
482     }
483     ++result;
484   }
485   return result;
486 }
487 
488 // Returns the total index of the requested sample.
489 // OrganizeByFontAndClass must have been already called.
GlobalSampleIndex(int font_id,int class_id,int index) const490 int TrainingSampleSet::GlobalSampleIndex(int font_id, int class_id, int index) const {
491   ASSERT_HOST(font_class_array_ != nullptr);
492   int font_index = font_id_map_.SparseToCompact(font_id);
493   if (font_index < 0) {
494     return -1;
495   }
496   return (*font_class_array_)(font_index, class_id).samples[index];
497 }
498 
499 // Gets the canonical sample for the given font, class pair.
500 // ComputeCanonicalSamples must have been called first.
GetCanonicalSample(int font_id,int class_id) const501 const TrainingSample *TrainingSampleSet::GetCanonicalSample(int font_id, int class_id) const {
502   ASSERT_HOST(font_class_array_ != nullptr);
503   int font_index = font_id_map_.SparseToCompact(font_id);
504   if (font_index < 0) {
505     return nullptr;
506   }
507   const int sample_index = (*font_class_array_)(font_index, class_id).canonical_sample;
508   return sample_index >= 0 ? samples_[sample_index] : nullptr;
509 }
510 
511 // Gets the max distance for the given canonical sample.
512 // ComputeCanonicalSamples must have been called first.
GetCanonicalDist(int font_id,int class_id) const513 float TrainingSampleSet::GetCanonicalDist(int font_id, int class_id) const {
514   ASSERT_HOST(font_class_array_ != nullptr);
515   int font_index = font_id_map_.SparseToCompact(font_id);
516   if (font_index < 0) {
517     return 0.0f;
518   }
519   if ((*font_class_array_)(font_index, class_id).canonical_sample >= 0) {
520     return (*font_class_array_)(font_index, class_id).canonical_dist;
521   } else {
522     return 0.0f;
523   }
524 }
525 
526 // Generates indexed features for all samples with the supplied feature_space.
IndexFeatures(const IntFeatureSpace & feature_space)527 void TrainingSampleSet::IndexFeatures(const IntFeatureSpace &feature_space) {
528   for (auto &sample : samples_) {
529     sample->IndexFeatures(feature_space);
530   }
531 }
532 
// Marks the given sample for deletion by setting its sample index to -1.
// Deletion is actually completed by DeleteDeadSamples.
// NOTE(review): DeleteDeadSamples selects samples by class_id() < 0, whereas
// this sets the sample index. Whether set_sample_index(-1) also makes
// class_id() negative is not visible here - TODO confirm.
void TrainingSampleSet::KillSample(TrainingSample *sample) {
  sample->set_sample_index(-1);
}
538 
539 // Deletes all samples with zero features marked by KillSample.
DeleteDeadSamples()540 void TrainingSampleSet::DeleteDeadSamples() {
541   using namespace std::placeholders; // for _1
542   auto old_it = samples_.begin();
543   for (; old_it < samples_.end(); ++old_it) {
544     if (*old_it == nullptr || (*old_it)->class_id() < 0) {
545       break;
546     }
547   }
548   auto new_it = old_it;
549   for (; old_it < samples_.end(); ++old_it) {
550     if (*old_it == nullptr || (*old_it)->class_id() < 0) {
551       delete *old_it;
552     } else {
553       *new_it = *old_it;
554       ++new_it;
555     }
556   }
557   samples_.resize(new_it - samples_.begin() + 1);
558   num_raw_samples_ = samples_.size();
559   // Samples must be re-organized now we have deleted a few.
560 }
561 
562 // Construct an array to access the samples by font,class pair.
OrganizeByFontAndClass()563 void TrainingSampleSet::OrganizeByFontAndClass() {
564   // Font indexes are sparse, so we used a map to compact them, so we can
565   // have an efficient 2-d array of fonts and character classes.
566   SetupFontIdMap();
567   int compact_font_size = font_id_map_.CompactSize();
568   // Get a 2-d array of generic vectors.
569   delete font_class_array_;
570   FontClassInfo empty;
571   font_class_array_ =
572       new GENERIC_2D_ARRAY<FontClassInfo>(compact_font_size, unicharset_size_, empty);
573   for (size_t s = 0; s < samples_.size(); ++s) {
574     int font_id = samples_[s]->font_id();
575     int class_id = samples_[s]->class_id();
576     if (font_id < 0 || font_id >= font_id_map_.SparseSize()) {
577       tprintf("Font id = %d/%d, class id = %d/%d on sample %zu\n", font_id,
578               font_id_map_.SparseSize(), class_id, unicharset_size_, s);
579     }
580     ASSERT_HOST(font_id >= 0 && font_id < font_id_map_.SparseSize());
581     ASSERT_HOST(class_id >= 0 && class_id < unicharset_size_);
582     int font_index = font_id_map_.SparseToCompact(font_id);
583     (*font_class_array_)(font_index, class_id).samples.push_back(s);
584   }
585   // Set the num_raw_samples member of the FontClassInfo, to set the boundary
586   // between the raw samples and the replicated ones.
587   for (int f = 0; f < compact_font_size; ++f) {
588     for (int c = 0; c < unicharset_size_; ++c) {
589       (*font_class_array_)(f, c).num_raw_samples = (*font_class_array_)(f, c).samples.size();
590     }
591   }
592   // This is the global number of samples and also marks the boundary between
593   // real and replicated samples.
594   num_raw_samples_ = samples_.size();
595 }
596 
597 // Constructs the font_id_map_ which maps real font_ids (sparse) to a compact
598 // index for the font_class_array_.
SetupFontIdMap()599 void TrainingSampleSet::SetupFontIdMap() {
600   // Number of samples for each font_id.
601   std::vector<int> font_counts;
602   for (auto &sample : samples_) {
603     const int font_id = sample->font_id();
604     while (font_id >= font_counts.size()) {
605       font_counts.push_back(0);
606     }
607     ++font_counts[font_id];
608   }
609   font_id_map_.Init(font_counts.size(), false);
610   for (size_t f = 0; f < font_counts.size(); ++f) {
611     font_id_map_.SetMap(f, font_counts[f] > 0);
612   }
613   font_id_map_.Setup();
614 }
615 
616 // Finds the sample for each font, class pair that has least maximum
617 // distance to all the other samples of the same font, class.
618 // OrganizeByFontAndClass must have been already called.
ComputeCanonicalSamples(const IntFeatureMap & map,bool debug)619 void TrainingSampleSet::ComputeCanonicalSamples(const IntFeatureMap &map, bool debug) {
620   ASSERT_HOST(font_class_array_ != nullptr);
621   IntFeatureDist f_table;
622   if (debug) {
623     tprintf("feature table size %d\n", map.sparse_size());
624   }
625   f_table.Init(&map);
626   int worst_s1 = 0;
627   int worst_s2 = 0;
628   double global_worst_dist = 0.0;
629   // Compute distances independently for each font and char index.
630   int font_size = font_id_map_.CompactSize();
631   for (int font_index = 0; font_index < font_size; ++font_index) {
632     int font_id = font_id_map_.CompactToSparse(font_index);
633     for (int c = 0; c < unicharset_size_; ++c) {
634       int samples_found = 0;
635       FontClassInfo &fcinfo = (*font_class_array_)(font_index, c);
636       if (fcinfo.samples.empty() || (kTestChar >= 0 && c != kTestChar)) {
637         fcinfo.canonical_sample = -1;
638         fcinfo.canonical_dist = 0.0f;
639         if (debug) {
640           tprintf("Skipping class %d\n", c);
641         }
642         continue;
643       }
644       // The canonical sample will be the one with the min_max_dist, which
645       // is the sample with the lowest maximum distance to all other samples.
646       double min_max_dist = 2.0;
647       // We keep track of the farthest apart pair (max_s1, max_s2) which
648       // are max_max_dist apart, so we can see how bad the variability is.
649       double max_max_dist = 0.0;
650       int max_s1 = 0;
651       int max_s2 = 0;
652       fcinfo.canonical_sample = fcinfo.samples[0];
653       fcinfo.canonical_dist = 0.0f;
654       for (auto s1 : fcinfo.samples) {
655         const std::vector<int> &features1 = samples_[s1]->indexed_features();
656         f_table.Set(features1, features1.size(), true);
657         double max_dist = 0.0;
658         // Run the full squared-order search for similar samples. It is still
659         // reasonably fast because f_table.FeatureDistance is fast, but we
660         // may have to reconsider if we start playing with too many samples
661         // of a single char/font.
662         for (int s2 : fcinfo.samples) {
663           if (samples_[s2]->class_id() != c || samples_[s2]->font_id() != font_id || s2 == s1) {
664             continue;
665           }
666           std::vector<int> features2 = samples_[s2]->indexed_features();
667           double dist = f_table.FeatureDistance(features2);
668           if (dist > max_dist) {
669             max_dist = dist;
670             if (dist > max_max_dist) {
671               max_max_dist = dist;
672               max_s1 = s1;
673               max_s2 = s2;
674             }
675           }
676         }
677         // Using Set(..., false) is far faster than re initializing, due to
678         // the sparseness of the feature space.
679         f_table.Set(features1, features1.size(), false);
680         samples_[s1]->set_max_dist(max_dist);
681         ++samples_found;
682         if (max_dist < min_max_dist) {
683           fcinfo.canonical_sample = s1;
684           fcinfo.canonical_dist = max_dist;
685         }
686         UpdateRange(max_dist, &min_max_dist, &max_max_dist);
687       }
688       if (max_max_dist > global_worst_dist) {
689         // Keep a record of the worst pair over all characters/fonts too.
690         global_worst_dist = max_max_dist;
691         worst_s1 = max_s1;
692         worst_s2 = max_s2;
693       }
694       if (debug) {
695         tprintf(
696             "Found %d samples of class %d=%s, font %d, "
697             "dist range [%g, %g], worst pair= %s, %s\n",
698             samples_found, c, unicharset_.debug_str(c).c_str(), font_index, min_max_dist,
699             max_max_dist, SampleToString(*samples_[max_s1]).c_str(),
700             SampleToString(*samples_[max_s2]).c_str());
701       }
702     }
703   }
704   if (debug) {
705     tprintf("Global worst dist = %g, between sample %d and %d\n", global_worst_dist, worst_s1,
706             worst_s2);
707   }
708 }
709 
710 // Replicates the samples to a minimum frequency defined by
711 // 2 * kSampleRandomSize, or for larger counts duplicates all samples.
712 // After replication, the replicated samples are perturbed slightly, but
713 // in a predictable and repeatable way.
714 // Use after OrganizeByFontAndClass().
ReplicateAndRandomizeSamples()715 void TrainingSampleSet::ReplicateAndRandomizeSamples() {
716   ASSERT_HOST(font_class_array_ != nullptr);
717   int font_size = font_id_map_.CompactSize();
718   for (int font_index = 0; font_index < font_size; ++font_index) {
719     for (int c = 0; c < unicharset_size_; ++c) {
720       FontClassInfo &fcinfo = (*font_class_array_)(font_index, c);
721       int sample_count = fcinfo.samples.size();
722       int min_samples = 2 * std::max(kSampleRandomSize, sample_count);
723       if (sample_count > 0 && sample_count < min_samples) {
724         int base_count = sample_count;
725         for (int base_index = 0; sample_count < min_samples; ++sample_count) {
726           int src_index = fcinfo.samples[base_index++];
727           if (base_index >= base_count) {
728             base_index = 0;
729           }
730           TrainingSample *sample =
731               samples_[src_index]->RandomizedCopy(sample_count % kSampleRandomSize);
732           int sample_index = samples_.size();
733           sample->set_sample_index(sample_index);
734           samples_.push_back(sample);
735           fcinfo.samples.push_back(sample_index);
736         }
737       }
738     }
739   }
740 }
741 
742 // Caches the indexed features of the canonical samples.
743 // ComputeCanonicalSamples must have been already called.
744 // TODO(rays) see note on ReliablySeparable and try restricting the
745 // canonical features to those that truly represent all samples.
ComputeCanonicalFeatures()746 void TrainingSampleSet::ComputeCanonicalFeatures() {
747   ASSERT_HOST(font_class_array_ != nullptr);
748   const int font_size = font_id_map_.CompactSize();
749   for (int font_index = 0; font_index < font_size; ++font_index) {
750     const int font_id = font_id_map_.CompactToSparse(font_index);
751     for (int c = 0; c < unicharset_size_; ++c) {
752       int num_samples = NumClassSamples(font_id, c, false);
753       if (num_samples == 0) {
754         continue;
755       }
756       const TrainingSample *sample = GetCanonicalSample(font_id, c);
757       FontClassInfo &fcinfo = (*font_class_array_)(font_index, c);
758       fcinfo.canonical_features = sample->indexed_features();
759     }
760   }
761 }
762 
763 // Computes the combined set of features used by all the samples of each
764 // font/class combination. Use after ReplicateAndRandomizeSamples.
ComputeCloudFeatures(int feature_space_size)765 void TrainingSampleSet::ComputeCloudFeatures(int feature_space_size) {
766   ASSERT_HOST(font_class_array_ != nullptr);
767   int font_size = font_id_map_.CompactSize();
768   for (int font_index = 0; font_index < font_size; ++font_index) {
769     int font_id = font_id_map_.CompactToSparse(font_index);
770     for (int c = 0; c < unicharset_size_; ++c) {
771       int num_samples = NumClassSamples(font_id, c, false);
772       if (num_samples == 0) {
773         continue;
774       }
775       FontClassInfo &fcinfo = (*font_class_array_)(font_index, c);
776       fcinfo.cloud_features.Init(feature_space_size);
777       for (int s = 0; s < num_samples; ++s) {
778         const TrainingSample *sample = GetSample(font_id, c, s);
779         const std::vector<int> &sample_features = sample->indexed_features();
780         for (int sample_feature : sample_features) {
781           fcinfo.cloud_features.SetBit(sample_feature);
782         }
783       }
784     }
785   }
786 }
787 
788 // Adds all fonts of the given class to the shape.
AddAllFontsForClass(int class_id,Shape * shape) const789 void TrainingSampleSet::AddAllFontsForClass(int class_id, Shape *shape) const {
790   for (int f = 0; f < font_id_map_.CompactSize(); ++f) {
791     const int font_id = font_id_map_.CompactToSparse(f);
792     shape->AddToShape(class_id, font_id);
793   }
794 }
795 
796 #ifndef GRAPHICS_DISABLED
797 
798 // Display the samples with the given indexed feature that also match
799 // the given shape.
DisplaySamplesWithFeature(int f_index,const Shape & shape,const IntFeatureSpace & space,ScrollView::Color color,ScrollView * window) const800 void TrainingSampleSet::DisplaySamplesWithFeature(int f_index, const Shape &shape,
801                                                   const IntFeatureSpace &space,
802                                                   ScrollView::Color color,
803                                                   ScrollView *window) const {
804   for (int s = 0; s < num_raw_samples(); ++s) {
805     const TrainingSample *sample = GetSample(s);
806     if (shape.ContainsUnichar(sample->class_id())) {
807       std::vector<int> indexed_features;
808       space.IndexAndSortFeatures(sample->features(), sample->num_features(), &indexed_features);
809       for (int indexed_feature : indexed_features) {
810         if (indexed_feature == f_index) {
811           sample->DisplayFeatures(color, window);
812         }
813       }
814     }
815   }
816 }
817 
818 #endif // !GRAPHICS_DISABLED
819 
820 } // namespace tesseract.
821