1 ///////////////////////////////////////////////////////////////////////
2 // File:        mastertrainer.cpp
3 // Description: Trainer to build the MasterClassifier.
4 // Author:      Ray Smith
5 //
6 // (C) Copyright 2010, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 ///////////////////////////////////////////////////////////////////////
18 
19 // Include automatically generated configuration file if running autoconf.
20 #ifdef HAVE_CONFIG_H
21 #  include "config_auto.h"
22 #endif
23 
24 #include <allheaders.h>
25 #include <cmath>
26 #include <ctime>
27 #include "boxread.h"
28 #include "classify.h"
29 #include "errorcounter.h"
30 #include "featdefs.h"
31 #include "mastertrainer.h"
32 #include "sampleiterator.h"
33 #include "shapeclassifier.h"
34 #include "shapetable.h"
35 #ifndef GRAPHICS_DISABLED
36 #  include "svmnode.h"
37 #endif
38 
39 #include "scanutils.h"
40 
41 namespace tesseract {
42 
// Tuning constants passed to ClusterShapes(). With a small kMinClusteredShapes
// and a large kMaxUnicharsPerCluster, kFontMergeDistance is the only limiting
// factor on how much gets merged.

// Minimum number of shapes that must remain in the clustered output.
constexpr int kMinClusteredShapes = 1;
// Upper bound on the number of unichars in any single cluster.
constexpr int kMaxUnicharsPerCluster = 2000;
// Fonts and unichars whose mean font distance is below this value are merged.
constexpr float kFontMergeDistance = 0.025;
51 
MasterTrainer(NormalizationMode norm_mode,bool shape_analysis,bool replicate_samples,int debug_level)52 MasterTrainer::MasterTrainer(NormalizationMode norm_mode, bool shape_analysis,
53                              bool replicate_samples, int debug_level)
54     : norm_mode_(norm_mode),
55       samples_(fontinfo_table_),
56       junk_samples_(fontinfo_table_),
57       verify_samples_(fontinfo_table_),
58       charsetsize_(0),
59       enable_shape_analysis_(shape_analysis),
60       enable_replication_(replicate_samples),
61       fragments_(nullptr),
62       prev_unichar_id_(-1),
63       debug_level_(debug_level) {}
64 
~MasterTrainer()65 MasterTrainer::~MasterTrainer() {
66   delete[] fragments_;
67   for (auto &page_image : page_images_) {
68     page_image.destroy();
69   }
70 }
71 
72 // WARNING! Serialize/DeSerialize are only partial, providing
73 // enough data to get the samples back and display them.
74 // Writes to the given file. Returns false in case of error.
Serialize(FILE * fp) const75 bool MasterTrainer::Serialize(FILE *fp) const {
76   uint32_t value = norm_mode_;
77   if (!tesseract::Serialize(fp, &value)) {
78     return false;
79   }
80   if (!unicharset_.save_to_file(fp)) {
81     return false;
82   }
83   if (!feature_space_.Serialize(fp)) {
84     return false;
85   }
86   if (!samples_.Serialize(fp)) {
87     return false;
88   }
89   if (!junk_samples_.Serialize(fp)) {
90     return false;
91   }
92   if (!verify_samples_.Serialize(fp)) {
93     return false;
94   }
95   if (!master_shapes_.Serialize(fp)) {
96     return false;
97   }
98   if (!flat_shapes_.Serialize(fp)) {
99     return false;
100   }
101   if (!fontinfo_table_.Serialize(fp)) {
102     return false;
103   }
104   if (!tesseract::Serialize(fp, xheights_)) {
105     return false;
106   }
107   return true;
108 }
109 
110 // Load an initial unicharset, or set one up if the file cannot be read.
LoadUnicharset(const char * filename)111 void MasterTrainer::LoadUnicharset(const char *filename) {
112   if (!unicharset_.load_from_file(filename)) {
113     tprintf(
114         "Failed to load unicharset from file %s\n"
115         "Building unicharset for training from scratch...\n",
116         filename);
117     unicharset_.clear();
118     UNICHARSET initialized;
119     // Add special characters, as they were removed by the clear, but the
120     // default constructor puts them in.
121     unicharset_.AppendOtherUnicharset(initialized);
122   }
123   charsetsize_ = unicharset_.size();
124   delete[] fragments_;
125   fragments_ = new int[charsetsize_];
126   memset(fragments_, 0, sizeof(*fragments_) * charsetsize_);
127   samples_.LoadUnicharset(filename);
128   junk_samples_.LoadUnicharset(filename);
129   verify_samples_.LoadUnicharset(filename);
130 }
131 
132 // Reads the samples and their features from the given .tr format file,
133 // adding them to the trainer with the font_id from the content of the file.
134 // See mftraining.cpp for a description of the file format.
135 // If verification, then these are verification samples, not training.
ReadTrainingSamples(const char * page_name,const FEATURE_DEFS_STRUCT & feature_defs,bool verification)136 void MasterTrainer::ReadTrainingSamples(const char *page_name,
137                                         const FEATURE_DEFS_STRUCT &feature_defs,
138                                         bool verification) {
139   char buffer[2048];
140   const int int_feature_type =
141       ShortNameToFeatureType(feature_defs, kIntFeatureType);
142   const int micro_feature_type =
143       ShortNameToFeatureType(feature_defs, kMicroFeatureType);
144   const int cn_feature_type =
145       ShortNameToFeatureType(feature_defs, kCNFeatureType);
146   const int geo_feature_type =
147       ShortNameToFeatureType(feature_defs, kGeoFeatureType);
148 
149   FILE *fp = fopen(page_name, "rb");
150   if (fp == nullptr) {
151     tprintf("Failed to open tr file: %s\n", page_name);
152     return;
153   }
154   tr_filenames_.emplace_back(page_name);
155   while (fgets(buffer, sizeof(buffer), fp) != nullptr) {
156     if (buffer[0] == '\n') {
157       continue;
158     }
159 
160     char *space = strchr(buffer, ' ');
161     if (space == nullptr) {
162       tprintf("Bad format in tr file, reading fontname, unichar\n");
163       continue;
164     }
165     *space++ = '\0';
166     int font_id = GetFontInfoId(buffer);
167     if (font_id < 0) {
168       font_id = 0;
169     }
170     int page_number;
171     std::string unichar;
172     TBOX bounding_box;
173     if (!ParseBoxFileStr(space, &page_number, unichar, &bounding_box)) {
174       tprintf("Bad format in tr file, reading box coords\n");
175       continue;
176     }
177     auto char_desc = ReadCharDescription(feature_defs, fp);
178     auto *sample = new TrainingSample;
179     sample->set_font_id(font_id);
180     sample->set_page_num(page_number + page_images_.size());
181     sample->set_bounding_box(bounding_box);
182     sample->ExtractCharDesc(int_feature_type, micro_feature_type,
183                             cn_feature_type, geo_feature_type, char_desc);
184     AddSample(verification, unichar.c_str(), sample);
185     delete char_desc;
186   }
187   charsetsize_ = unicharset_.size();
188   fclose(fp);
189 }
190 
191 // Adds the given single sample to the trainer, setting the classid
192 // appropriately from the given unichar_str.
AddSample(bool verification,const char * unichar,TrainingSample * sample)193 void MasterTrainer::AddSample(bool verification, const char *unichar,
194                               TrainingSample *sample) {
195   if (verification) {
196     verify_samples_.AddSample(unichar, sample);
197     prev_unichar_id_ = -1;
198   } else if (unicharset_.contains_unichar(unichar)) {
199     if (prev_unichar_id_ >= 0) {
200       fragments_[prev_unichar_id_] = -1;
201     }
202     prev_unichar_id_ = samples_.AddSample(unichar, sample);
203     if (flat_shapes_.FindShape(prev_unichar_id_, sample->font_id()) < 0) {
204       flat_shapes_.AddShape(prev_unichar_id_, sample->font_id());
205     }
206   } else {
207     const int junk_id = junk_samples_.AddSample(unichar, sample);
208     if (prev_unichar_id_ >= 0) {
209       CHAR_FRAGMENT *frag = CHAR_FRAGMENT::parse_from_string(unichar);
210       if (frag != nullptr && frag->is_natural()) {
211         if (fragments_[prev_unichar_id_] == 0) {
212           fragments_[prev_unichar_id_] = junk_id;
213         } else if (fragments_[prev_unichar_id_] != junk_id) {
214           fragments_[prev_unichar_id_] = -1;
215         }
216       }
217       delete frag;
218     }
219     prev_unichar_id_ = -1;
220   }
221 }
222 
223 // Loads all pages from the given tif filename and append to page_images_.
224 // Must be called after ReadTrainingSamples, as the current number of images
225 // is used as an offset for page numbers in the samples.
LoadPageImages(const char * filename)226 void MasterTrainer::LoadPageImages(const char *filename) {
227   size_t offset = 0;
228   int page;
229   Image pix;
230   for (page = 0;; page++) {
231     pix = pixReadFromMultipageTiff(filename, &offset);
232     if (!pix) {
233       break;
234     }
235     page_images_.push_back(pix);
236     if (!offset) {
237       break;
238     }
239   }
240   tprintf("Loaded %d page images from %s\n", page, filename);
241 }
242 
243 // Cleans up the samples after initial load from the tr files, and prior to
244 // saving the MasterTrainer:
245 // Remaps fragmented chars if running shape analysis.
246 // Sets up the samples appropriately for class/fontwise access.
247 // Deletes outlier samples.
PostLoadCleanup()248 void MasterTrainer::PostLoadCleanup() {
249   if (debug_level_ > 0) {
250     tprintf("PostLoadCleanup...\n");
251   }
252   if (enable_shape_analysis_) {
253     ReplaceFragmentedSamples();
254   }
255   SampleIterator sample_it;
256   sample_it.Init(nullptr, nullptr, true, &verify_samples_);
257   sample_it.NormalizeSamples();
258   verify_samples_.OrganizeByFontAndClass();
259 
260   samples_.IndexFeatures(feature_space_);
261   // TODO(rays) DeleteOutliers is currently turned off to prove NOP-ness
262   // against current training.
263   //  samples_.DeleteOutliers(feature_space_, debug_level_ > 0);
264   samples_.OrganizeByFontAndClass();
265   if (debug_level_ > 0) {
266     tprintf("ComputeCanonicalSamples...\n");
267   }
268   samples_.ComputeCanonicalSamples(feature_map_, debug_level_ > 0);
269 }
270 
271 // Gets the samples ready for training. Use after both
272 // ReadTrainingSamples+PostLoadCleanup or DeSerialize.
273 // Re-indexes the features and computes canonical and cloud features.
PreTrainingSetup()274 void MasterTrainer::PreTrainingSetup() {
275   if (debug_level_ > 0) {
276     tprintf("PreTrainingSetup...\n");
277   }
278   samples_.IndexFeatures(feature_space_);
279   samples_.ComputeCanonicalFeatures();
280   if (debug_level_ > 0) {
281     tprintf("ComputeCloudFeatures...\n");
282   }
283   samples_.ComputeCloudFeatures(feature_space_.Size());
284 }
285 
286 // Sets up the master_shapes_ table, which tells which fonts should stay
287 // together until they get to a leaf node classifier.
SetupMasterShapes()288 void MasterTrainer::SetupMasterShapes() {
289   tprintf("Building master shape table\n");
290   const int num_fonts = samples_.NumFonts();
291 
292   ShapeTable char_shapes_begin_fragment(samples_.unicharset());
293   ShapeTable char_shapes_end_fragment(samples_.unicharset());
294   ShapeTable char_shapes(samples_.unicharset());
295   for (int c = 0; c < samples_.charsetsize(); ++c) {
296     ShapeTable shapes(samples_.unicharset());
297     for (int f = 0; f < num_fonts; ++f) {
298       if (samples_.NumClassSamples(f, c, true) > 0) {
299         shapes.AddShape(c, f);
300       }
301     }
302     ClusterShapes(kMinClusteredShapes, 1, kFontMergeDistance, &shapes);
303 
304     const CHAR_FRAGMENT *fragment = samples_.unicharset().get_fragment(c);
305 
306     if (fragment == nullptr) {
307       char_shapes.AppendMasterShapes(shapes, nullptr);
308     } else if (fragment->is_beginning()) {
309       char_shapes_begin_fragment.AppendMasterShapes(shapes, nullptr);
310     } else if (fragment->is_ending()) {
311       char_shapes_end_fragment.AppendMasterShapes(shapes, nullptr);
312     } else {
313       char_shapes.AppendMasterShapes(shapes, nullptr);
314     }
315   }
316   ClusterShapes(kMinClusteredShapes, kMaxUnicharsPerCluster, kFontMergeDistance,
317                 &char_shapes_begin_fragment);
318   char_shapes.AppendMasterShapes(char_shapes_begin_fragment, nullptr);
319   ClusterShapes(kMinClusteredShapes, kMaxUnicharsPerCluster, kFontMergeDistance,
320                 &char_shapes_end_fragment);
321   char_shapes.AppendMasterShapes(char_shapes_end_fragment, nullptr);
322   ClusterShapes(kMinClusteredShapes, kMaxUnicharsPerCluster, kFontMergeDistance,
323                 &char_shapes);
324   master_shapes_.AppendMasterShapes(char_shapes, nullptr);
325   tprintf("Master shape_table:%s\n", master_shapes_.SummaryStr().c_str());
326 }
327 
328 // Adds the junk_samples_ to the main samples_ set. Junk samples are initially
329 // fragments and n-grams (all incorrectly segmented characters).
330 // Various training functions may result in incorrectly segmented characters
331 // being added to the unicharset of the main samples, perhaps because they
332 // form a "radical" decomposition of some (Indic) grapheme, or because they
333 // just look the same as a real character (like rn/m)
334 // This function moves all the junk samples, to the main samples_ set, but
335 // desirable junk, being any sample for which the unichar already exists in
336 // the samples_ unicharset gets the unichar-ids re-indexed to match, but
337 // anything else gets re-marked as unichar_id 0 (space character) to identify
338 // it as junk to the error counter.
IncludeJunk()339 void MasterTrainer::IncludeJunk() {
340   // Get ids of fragments in junk_samples_ that replace the dead chars.
341   const UNICHARSET &junk_set = junk_samples_.unicharset();
342   const UNICHARSET &sample_set = samples_.unicharset();
343   int num_junks = junk_samples_.num_samples();
344   tprintf("Moving %d junk samples to master sample set.\n", num_junks);
345   for (int s = 0; s < num_junks; ++s) {
346     TrainingSample *sample = junk_samples_.mutable_sample(s);
347     int junk_id = sample->class_id();
348     const char *junk_utf8 = junk_set.id_to_unichar(junk_id);
349     int sample_id = sample_set.unichar_to_id(junk_utf8);
350     if (sample_id == INVALID_UNICHAR_ID) {
351       sample_id = 0;
352     }
353     sample->set_class_id(sample_id);
354     junk_samples_.extract_sample(s);
355     samples_.AddSample(sample_id, sample);
356   }
357   junk_samples_.DeleteDeadSamples();
358   samples_.OrganizeByFontAndClass();
359 }
360 
361 // Replicates the samples and perturbs them if the enable_replication_ flag
362 // is set. MUST be used after the last call to OrganizeByFontAndClass on
363 // the training samples, ie after IncludeJunk if it is going to be used, as
364 // OrganizeByFontAndClass will eat the replicated samples into the regular
365 // samples.
ReplicateAndRandomizeSamplesIfRequired()366 void MasterTrainer::ReplicateAndRandomizeSamplesIfRequired() {
367   if (enable_replication_) {
368     if (debug_level_ > 0) {
369       tprintf("ReplicateAndRandomize...\n");
370     }
371     verify_samples_.ReplicateAndRandomizeSamples();
372     samples_.ReplicateAndRandomizeSamples();
373     samples_.IndexFeatures(feature_space_);
374   }
375 }
376 
377 // Loads the basic font properties file into fontinfo_table_.
378 // Returns false on failure.
LoadFontInfo(const char * filename)379 bool MasterTrainer::LoadFontInfo(const char *filename) {
380   FILE *fp = fopen(filename, "rb");
381   if (fp == nullptr) {
382     fprintf(stderr, "Failed to load font_properties from %s\n", filename);
383     return false;
384   }
385   int italic, bold, fixed, serif, fraktur;
386   while (!feof(fp)) {
387     FontInfo fontinfo;
388     char *font_name = new char[1024];
389     fontinfo.name = font_name;
390     fontinfo.properties = 0;
391     fontinfo.universal_id = 0;
392     if (tfscanf(fp, "%1024s %i %i %i %i %i\n", font_name, &italic, &bold,
393                 &fixed, &serif, &fraktur) != 6) {
394       delete[] font_name;
395       continue;
396     }
397     fontinfo.properties = (italic << 0) + (bold << 1) + (fixed << 2) +
398                           (serif << 3) + (fraktur << 4);
399     if (fontinfo_table_.get_index(fontinfo) < 0) {
400       // fontinfo not in table.
401       fontinfo_table_.push_back(fontinfo);
402     } else {
403       delete[] font_name;
404     }
405   }
406   fclose(fp);
407   return true;
408 }
409 
410 // Loads the xheight font properties file into xheights_.
411 // Returns false on failure.
LoadXHeights(const char * filename)412 bool MasterTrainer::LoadXHeights(const char *filename) {
413   tprintf("fontinfo table is of size %d\n", fontinfo_table_.size());
414   xheights_.clear();
415   xheights_.resize(fontinfo_table_.size(), -1);
416   if (filename == nullptr) {
417     return true;
418   }
419   FILE *f = fopen(filename, "rb");
420   if (f == nullptr) {
421     fprintf(stderr, "Failed to load font xheights from %s\n", filename);
422     return false;
423   }
424   tprintf("Reading x-heights from %s ...\n", filename);
425   FontInfo fontinfo;
426   fontinfo.properties = 0; // Not used to lookup in the table.
427   fontinfo.universal_id = 0;
428   char buffer[1024];
429   int xht;
430   int total_xheight = 0;
431   int xheight_count = 0;
432   while (!feof(f)) {
433     if (tfscanf(f, "%1023s %d\n", buffer, &xht) != 2) {
434       continue;
435     }
436     buffer[1023] = '\0';
437     fontinfo.name = buffer;
438     auto fontinfo_id = fontinfo_table_.get_index(fontinfo);
439     if (fontinfo_id < 0) {
440       // fontinfo not in table.
441       continue;
442     }
443     xheights_[fontinfo_id] = xht;
444     total_xheight += xht;
445     ++xheight_count;
446   }
447   if (xheight_count == 0) {
448     fprintf(stderr, "No valid xheights in %s!\n", filename);
449     fclose(f);
450     return false;
451   }
452   int mean_xheight = DivRounded(total_xheight, xheight_count);
453   for (int i = 0; i < fontinfo_table_.size(); ++i) {
454     if (xheights_[i] < 0) {
455       xheights_[i] = mean_xheight;
456     }
457   }
458   fclose(f);
459   return true;
460 } // LoadXHeights
461 
462 // Reads spacing stats from filename and adds them to fontinfo_table.
AddSpacingInfo(const char * filename)463 bool MasterTrainer::AddSpacingInfo(const char *filename) {
464   FILE *fontinfo_file = fopen(filename, "rb");
465   if (fontinfo_file == nullptr) {
466     return true; // We silently ignore missing files!
467   }
468   // Find the fontinfo_id.
469   int fontinfo_id = GetBestMatchingFontInfoId(filename);
470   if (fontinfo_id < 0) {
471     tprintf("No font found matching fontinfo filename %s\n", filename);
472     fclose(fontinfo_file);
473     return false;
474   }
475   tprintf("Reading spacing from %s for font %d...\n", filename, fontinfo_id);
476   // TODO(rays) scale should probably be a double, but keep as an int for now
477   // to duplicate current behavior.
478   int scale = kBlnXHeight / xheights_[fontinfo_id];
479   int num_unichars;
480   char uch[UNICHAR_LEN];
481   char kerned_uch[UNICHAR_LEN];
482   int x_gap, x_gap_before, x_gap_after, num_kerned;
483   ASSERT_HOST(tfscanf(fontinfo_file, "%d\n", &num_unichars) == 1);
484   FontInfo *fi = &fontinfo_table_.at(fontinfo_id);
485   fi->init_spacing(unicharset_.size());
486   FontSpacingInfo *spacing = nullptr;
487   for (int l = 0; l < num_unichars; ++l) {
488     if (tfscanf(fontinfo_file, "%s %d %d %d", uch, &x_gap_before, &x_gap_after,
489                 &num_kerned) != 4) {
490       tprintf("Bad format of font spacing file %s\n", filename);
491       fclose(fontinfo_file);
492       return false;
493     }
494     bool valid = unicharset_.contains_unichar(uch);
495     if (valid) {
496       spacing = new FontSpacingInfo();
497       spacing->x_gap_before = static_cast<int16_t>(x_gap_before * scale);
498       spacing->x_gap_after = static_cast<int16_t>(x_gap_after * scale);
499     }
500     for (int k = 0; k < num_kerned; ++k) {
501       if (tfscanf(fontinfo_file, "%s %d", kerned_uch, &x_gap) != 2) {
502         tprintf("Bad format of font spacing file %s\n", filename);
503         fclose(fontinfo_file);
504         delete spacing;
505         return false;
506       }
507       if (!valid || !unicharset_.contains_unichar(kerned_uch)) {
508         continue;
509       }
510       spacing->kerned_unichar_ids.push_back(
511           unicharset_.unichar_to_id(kerned_uch));
512       spacing->kerned_x_gaps.push_back(static_cast<int16_t>(x_gap * scale));
513     }
514     if (valid) {
515       fi->add_spacing(unicharset_.unichar_to_id(uch), spacing);
516     }
517   }
518   fclose(fontinfo_file);
519   return true;
520 }
521 
522 // Returns the font id corresponding to the given font name.
523 // Returns -1 if the font cannot be found.
GetFontInfoId(const char * font_name)524 int MasterTrainer::GetFontInfoId(const char *font_name) {
525   FontInfo fontinfo;
526   // We are only borrowing the string, so it is OK to const cast it.
527   fontinfo.name = const_cast<char *>(font_name);
528   fontinfo.properties = 0; // Not used to lookup in the table
529   fontinfo.universal_id = 0;
530   return fontinfo_table_.get_index(fontinfo);
531 }
532 // Returns the font_id of the closest matching font name to the given
533 // filename. It is assumed that a substring of the filename will match
534 // one of the fonts. If more than one is matched, the longest is returned.
GetBestMatchingFontInfoId(const char * filename)535 int MasterTrainer::GetBestMatchingFontInfoId(const char *filename) {
536   int fontinfo_id = -1;
537   int best_len = 0;
538   for (int f = 0; f < fontinfo_table_.size(); ++f) {
539     if (strstr(filename, fontinfo_table_.at(f).name) != nullptr) {
540       int len = strlen(fontinfo_table_.at(f).name);
541       // Use the longest matching length in case a substring of a font matched.
542       if (len > best_len) {
543         best_len = len;
544         fontinfo_id = f;
545       }
546     }
547   }
548   return fontinfo_id;
549 }
550 
551 // Sets up a flat shapetable with one shape per class/font combination.
SetupFlatShapeTable(ShapeTable * shape_table)552 void MasterTrainer::SetupFlatShapeTable(ShapeTable *shape_table) {
553   // To exactly mimic the results of the previous implementation, the shapes
554   // must be clustered in order the fonts arrived, and reverse order of the
555   // characters within each font.
556   // Get a list of the fonts in the order they appeared.
557   std::vector<int> active_fonts;
558   int num_shapes = flat_shapes_.NumShapes();
559   for (int s = 0; s < num_shapes; ++s) {
560     int font = flat_shapes_.GetShape(s)[0].font_ids[0];
561     unsigned f = 0;
562     for (f = 0; f < active_fonts.size(); ++f) {
563       if (active_fonts[f] == font) {
564         break;
565       }
566     }
567     if (f == active_fonts.size()) {
568       active_fonts.push_back(font);
569     }
570   }
571   // For each font in order, add all the shapes with that font in reverse order.
572   int num_fonts = active_fonts.size();
573   for (int f = 0; f < num_fonts; ++f) {
574     for (int s = num_shapes - 1; s >= 0; --s) {
575       int font = flat_shapes_.GetShape(s)[0].font_ids[0];
576       if (font == active_fonts[f]) {
577         shape_table->AddShape(flat_shapes_.GetShape(s));
578       }
579     }
580   }
581 }
582 
583 // Sets up a Clusterer for mftraining on a single shape_id.
584 // Call FreeClusterer on the return value after use.
SetupForClustering(const ShapeTable & shape_table,const FEATURE_DEFS_STRUCT & feature_defs,int shape_id,int * num_samples)585 CLUSTERER *MasterTrainer::SetupForClustering(
586     const ShapeTable &shape_table, const FEATURE_DEFS_STRUCT &feature_defs,
587     int shape_id, int *num_samples) {
588   int desc_index = ShortNameToFeatureType(feature_defs, kMicroFeatureType);
589   int num_params = feature_defs.FeatureDesc[desc_index]->NumParams;
590   ASSERT_HOST(num_params == (int)MicroFeatureParameter::MFCount);
591   CLUSTERER *clusterer = MakeClusterer(
592       num_params, feature_defs.FeatureDesc[desc_index]->ParamDesc);
593 
594   // We want to iterate over the samples of just the one shape.
595   IndexMapBiDi shape_map;
596   shape_map.Init(shape_table.NumShapes(), false);
597   shape_map.SetMap(shape_id, true);
598   shape_map.Setup();
599   // Reverse the order of the samples to match the previous behavior.
600   std::vector<const TrainingSample *> sample_ptrs;
601   SampleIterator it;
602   it.Init(&shape_map, &shape_table, false, &samples_);
603   for (it.Begin(); !it.AtEnd(); it.Next()) {
604     sample_ptrs.push_back(&it.GetSample());
605   }
606   uint32_t sample_id = 0;
607   for (int i = sample_ptrs.size() - 1; i >= 0; --i) {
608     const TrainingSample *sample = sample_ptrs[i];
609     uint32_t num_features = sample->num_micro_features();
610     for (uint32_t f = 0; f < num_features; ++f) {
611       MakeSample(clusterer, sample->micro_features()[f].data(), sample_id);
612     }
613     ++sample_id;
614   }
615   *num_samples = sample_id;
616   return clusterer;
617 }
618 
619 // Writes the given float_classes (produced by SetupForFloat2Int) as inttemp
620 // to the given inttemp_file, and the corresponding pffmtable.
621 // The unicharset is the original encoding of graphemes, and shape_set should
622 // match the size of the shape_table, and may possibly be totally fake.
WriteInttempAndPFFMTable(const UNICHARSET & unicharset,const UNICHARSET & shape_set,const ShapeTable & shape_table,CLASS_STRUCT * float_classes,const char * inttemp_file,const char * pffmtable_file)623 void MasterTrainer::WriteInttempAndPFFMTable(const UNICHARSET &unicharset,
624                                              const UNICHARSET &shape_set,
625                                              const ShapeTable &shape_table,
626                                              CLASS_STRUCT *float_classes,
627                                              const char *inttemp_file,
628                                              const char *pffmtable_file) {
629   auto *classify = new tesseract::Classify();
630   // Move the fontinfo table to classify.
631   fontinfo_table_.MoveTo(&classify->get_fontinfo_table());
632   INT_TEMPLATES_STRUCT *int_templates =
633       classify->CreateIntTemplates(float_classes, shape_set);
634   FILE *fp = fopen(inttemp_file, "wb");
635   if (fp == nullptr) {
636     tprintf("Error, failed to open file \"%s\"\n", inttemp_file);
637   } else {
638     classify->WriteIntTemplates(fp, int_templates, shape_set);
639     fclose(fp);
640   }
641   // Now write pffmtable. This is complicated by the fact that the adaptive
642   // classifier still wants one indexed by unichar-id, but the static
643   // classifier needs one indexed by its shape class id.
644   // We put the shapetable_cutoffs in a vector, and compute the
645   // unicharset cutoffs along the way.
646   std::vector<uint16_t> shapetable_cutoffs;
647   std::vector<uint16_t> unichar_cutoffs(unicharset.size());
648   /* then write out each class */
649   for (int i = 0; i < int_templates->NumClasses; ++i) {
650     INT_CLASS_STRUCT *Class = ClassForClassId(int_templates, i);
651     // Todo: Test with min instead of max
652     // int MaxLength = LengthForConfigId(Class, 0);
653     uint16_t max_length = 0;
654     for (int config_id = 0; config_id < Class->NumConfigs; config_id++) {
655       // Todo: Test with min instead of max
656       // if (LengthForConfigId (Class, config_id) < MaxLength)
657       uint16_t length = Class->ConfigLengths[config_id];
658       if (length > max_length) {
659         max_length = Class->ConfigLengths[config_id];
660       }
661       int shape_id = float_classes[i].font_set.at(config_id);
662       const Shape &shape = shape_table.GetShape(shape_id);
663       for (int c = 0; c < shape.size(); ++c) {
664         int unichar_id = shape[c].unichar_id;
665         if (length > unichar_cutoffs[unichar_id]) {
666           unichar_cutoffs[unichar_id] = length;
667         }
668       }
669     }
670     shapetable_cutoffs.push_back(max_length);
671   }
672   fp = fopen(pffmtable_file, "wb");
673   if (fp == nullptr) {
674     tprintf("Error, failed to open file \"%s\"\n", pffmtable_file);
675   } else {
676     tesseract::Serialize(fp, shapetable_cutoffs);
677     for (int c = 0; c < unicharset.size(); ++c) {
678       const char *unichar = unicharset.id_to_unichar(c);
679       if (strcmp(unichar, " ") == 0) {
680         unichar = "NULL";
681       }
682       fprintf(fp, "%s %d\n", unichar, unichar_cutoffs[c]);
683     }
684     fclose(fp);
685   }
686   delete int_templates;
687   delete classify;
688 }
689 
690 // Generate debug output relating to the canonical distance between the
691 // two given UTF8 grapheme strings.
DebugCanonical(const char * unichar_str1,const char * unichar_str2)692 void MasterTrainer::DebugCanonical(const char *unichar_str1,
693                                    const char *unichar_str2) {
694   int class_id1 = unicharset_.unichar_to_id(unichar_str1);
695   int class_id2 = unicharset_.unichar_to_id(unichar_str2);
696   if (class_id2 == INVALID_UNICHAR_ID) {
697     class_id2 = class_id1;
698   }
699   if (class_id1 == INVALID_UNICHAR_ID) {
700     tprintf("No unicharset entry found for %s\n", unichar_str1);
701     return;
702   } else {
703     tprintf("Font ambiguities for unichar %d = %s and %d = %s\n", class_id1,
704             unichar_str1, class_id2, unichar_str2);
705   }
706   int num_fonts = samples_.NumFonts();
707   const IntFeatureMap &feature_map = feature_map_;
708   // Iterate the fonts to get the similarity with other fonst of the same
709   // class.
710   tprintf("      ");
711   for (int f = 0; f < num_fonts; ++f) {
712     if (samples_.NumClassSamples(f, class_id2, false) == 0) {
713       continue;
714     }
715     tprintf("%6d", f);
716   }
717   tprintf("\n");
718   for (int f1 = 0; f1 < num_fonts; ++f1) {
719     // Map the features of the canonical_sample.
720     if (samples_.NumClassSamples(f1, class_id1, false) == 0) {
721       continue;
722     }
723     tprintf("%4d  ", f1);
724     for (int f2 = 0; f2 < num_fonts; ++f2) {
725       if (samples_.NumClassSamples(f2, class_id2, false) == 0) {
726         continue;
727       }
728       float dist =
729           samples_.ClusterDistance(f1, class_id1, f2, class_id2, feature_map);
730       tprintf(" %5.3f", dist);
731     }
732     tprintf("\n");
733   }
734   // Build a fake ShapeTable containing all the sample types.
735   ShapeTable shapes(unicharset_);
736   for (int f = 0; f < num_fonts; ++f) {
737     if (samples_.NumClassSamples(f, class_id1, true) > 0) {
738       shapes.AddShape(class_id1, f);
739     }
740     if (class_id1 != class_id2 &&
741         samples_.NumClassSamples(f, class_id2, true) > 0) {
742       shapes.AddShape(class_id2, f);
743     }
744   }
745 }
746 
747 #ifndef GRAPHICS_DISABLED
748 // Debugging for cloud/canonical features.
749 // Displays a Features window containing:
750 // If unichar_str2 is in the unicharset, and canonical_font is non-negative,
751 // displays the canonical features of the char/font combination in red.
752 // If unichar_str1 is in the unicharset, and cloud_font is non-negative,
753 // displays the cloud feature of the char/font combination in green.
754 // The canonical features are drawn first to show which ones have no
755 // matches in the cloud features.
756 // Until the features window is destroyed, each click in the features window
757 // will display the samples that have that feature in a separate window.
DisplaySamples(const char * unichar_str1,int cloud_font,const char * unichar_str2,int canonical_font)758 void MasterTrainer::DisplaySamples(const char *unichar_str1, int cloud_font,
759                                    const char *unichar_str2,
760                                    int canonical_font) {
761   const IntFeatureMap &feature_map = feature_map_;
762   const IntFeatureSpace &feature_space = feature_map.feature_space();
763   ScrollView *f_window = CreateFeatureSpaceWindow("Features", 100, 500);
764   ClearFeatureSpaceWindow(norm_mode_ == NM_BASELINE ? baseline : character,
765                           f_window);
766   int class_id2 = samples_.unicharset().unichar_to_id(unichar_str2);
767   if (class_id2 != INVALID_UNICHAR_ID && canonical_font >= 0) {
768     const TrainingSample *sample =
769         samples_.GetCanonicalSample(canonical_font, class_id2);
770     for (uint32_t f = 0; f < sample->num_features(); ++f) {
771       RenderIntFeature(f_window, &sample->features()[f], ScrollView::RED);
772     }
773   }
774   int class_id1 = samples_.unicharset().unichar_to_id(unichar_str1);
775   if (class_id1 != INVALID_UNICHAR_ID && cloud_font >= 0) {
776     const BitVector &cloud = samples_.GetCloudFeatures(cloud_font, class_id1);
777     for (int f = 0; f < cloud.size(); ++f) {
778       if (cloud[f]) {
779         INT_FEATURE_STRUCT feature = feature_map.InverseIndexFeature(f);
780         RenderIntFeature(f_window, &feature, ScrollView::GREEN);
781       }
782     }
783   }
784   f_window->Update();
785   ScrollView *s_window = CreateFeatureSpaceWindow("Samples", 100, 500);
786   SVEventType ev_type;
787   do {
788     SVEvent *ev;
789     // Wait until a click or popup event.
790     ev = f_window->AwaitEvent(SVET_ANY);
791     ev_type = ev->type;
792     if (ev_type == SVET_CLICK) {
793       int feature_index = feature_space.XYToFeatureIndex(ev->x, ev->y);
794       if (feature_index >= 0) {
795         // Iterate samples and display those with the feature.
796         Shape shape;
797         shape.AddToShape(class_id1, cloud_font);
798         s_window->Clear();
799         samples_.DisplaySamplesWithFeature(feature_index, shape, feature_space,
800                                            ScrollView::GREEN, s_window);
801         s_window->Update();
802       }
803     }
804     delete ev;
805   } while (ev_type != SVET_DESTROY);
806 }
807 #endif // !GRAPHICS_DISABLED
808 
TestClassifierVOld(bool replicate_samples,ShapeClassifier * test_classifier,ShapeClassifier * old_classifier)809 void MasterTrainer::TestClassifierVOld(bool replicate_samples,
810                                        ShapeClassifier *test_classifier,
811                                        ShapeClassifier *old_classifier) {
812   SampleIterator sample_it;
813   sample_it.Init(nullptr, nullptr, replicate_samples, &samples_);
814   ErrorCounter::DebugNewErrors(test_classifier, old_classifier,
815                                CT_UNICHAR_TOPN_ERR, fontinfo_table_,
816                                page_images_, &sample_it);
817 }
818 
819 // Tests the given test_classifier on the internal samples.
820 // See TestClassifier for details.
TestClassifierOnSamples(CountTypes error_mode,int report_level,bool replicate_samples,ShapeClassifier * test_classifier,std::string * report_string)821 void MasterTrainer::TestClassifierOnSamples(CountTypes error_mode,
822                                             int report_level,
823                                             bool replicate_samples,
824                                             ShapeClassifier *test_classifier,
825                                             std::string *report_string) {
826   TestClassifier(error_mode, report_level, replicate_samples, &samples_,
827                  test_classifier, report_string);
828 }
829 
830 // Tests the given test_classifier on the given samples.
831 // error_mode indicates what counts as an error.
832 // report_levels:
833 // 0 = no output.
834 // 1 = bottom-line error rate.
835 // 2 = bottom-line error rate + time.
836 // 3 = font-level error rate + time.
837 // 4 = list of all errors + short classifier debug output on 16 errors.
838 // 5 = list of all errors + short classifier debug output on 25 errors.
839 // If replicate_samples is true, then the test is run on an extended test
840 // sample including replicated and systematically perturbed samples.
841 // If report_string is non-nullptr, a summary of the results for each font
842 // is appended to the report_string.
TestClassifier(CountTypes error_mode,int report_level,bool replicate_samples,TrainingSampleSet * samples,ShapeClassifier * test_classifier,std::string * report_string)843 double MasterTrainer::TestClassifier(CountTypes error_mode, int report_level,
844                                      bool replicate_samples,
845                                      TrainingSampleSet *samples,
846                                      ShapeClassifier *test_classifier,
847                                      std::string *report_string) {
848   SampleIterator sample_it;
849   sample_it.Init(nullptr, nullptr, replicate_samples, samples);
850   if (report_level > 0) {
851     int num_samples = 0;
852     for (sample_it.Begin(); !sample_it.AtEnd(); sample_it.Next()) {
853       ++num_samples;
854     }
855     tprintf("Iterator has charset size of %d/%d, %d shapes, %d samples\n",
856             sample_it.SparseCharsetSize(), sample_it.CompactCharsetSize(),
857             test_classifier->GetShapeTable()->NumShapes(), num_samples);
858     tprintf("Testing %sREPLICATED:\n", replicate_samples ? "" : "NON-");
859   }
860   double unichar_error = 0.0;
861   ErrorCounter::ComputeErrorRate(test_classifier, report_level, error_mode,
862                                  fontinfo_table_, page_images_, &sample_it,
863                                  &unichar_error, nullptr, report_string);
864   return unichar_error;
865 }
866 
867 // Returns the average (in some sense) distance between the two given
868 // shapes, which may contain multiple fonts and/or unichars.
ShapeDistance(const ShapeTable & shapes,int s1,int s2)869 float MasterTrainer::ShapeDistance(const ShapeTable &shapes, int s1, int s2) {
870   const IntFeatureMap &feature_map = feature_map_;
871   const Shape &shape1 = shapes.GetShape(s1);
872   const Shape &shape2 = shapes.GetShape(s2);
873   int num_chars1 = shape1.size();
874   int num_chars2 = shape2.size();
875   float dist_sum = 0.0f;
876   int dist_count = 0;
877   if (num_chars1 > 1 || num_chars2 > 1) {
878     // In the multi-char case try to optimize the calculation by computing
879     // distances between characters of matching font where possible.
880     for (int c1 = 0; c1 < num_chars1; ++c1) {
881       for (int c2 = 0; c2 < num_chars2; ++c2) {
882         dist_sum +=
883             samples_.UnicharDistance(shape1[c1], shape2[c2], true, feature_map);
884         ++dist_count;
885       }
886     }
887   } else {
888     // In the single unichar case, there is little alternative, but to compute
889     // the squared-order distance between pairs of fonts.
890     dist_sum =
891         samples_.UnicharDistance(shape1[0], shape2[0], false, feature_map);
892     ++dist_count;
893   }
894   return dist_sum / dist_count;
895 }
896 
897 // Replaces samples that are always fragmented with the corresponding
898 // fragment samples.
ReplaceFragmentedSamples()899 void MasterTrainer::ReplaceFragmentedSamples() {
900   if (fragments_ == nullptr) {
901     return;
902   }
903   // Remove samples that are replaced by fragments. Each class that was
904   // always naturally fragmented should be replaced by its fragments.
905   int num_samples = samples_.num_samples();
906   for (int s = 0; s < num_samples; ++s) {
907     TrainingSample *sample = samples_.mutable_sample(s);
908     if (fragments_[sample->class_id()] > 0) {
909       samples_.KillSample(sample);
910     }
911   }
912   samples_.DeleteDeadSamples();
913 
914   // Get ids of fragments in junk_samples_ that replace the dead chars.
915   const UNICHARSET &frag_set = junk_samples_.unicharset();
916 #if 0
917   // TODO(rays) The original idea was to replace only graphemes that were
918   // always naturally fragmented, but that left a lot of the Indic graphemes
919   // out. Determine whether we can go back to that idea now that spacing
920   // is fixed in the training images, or whether this code is obsolete.
921   bool* good_junk = new bool[frag_set.size()];
922   memset(good_junk, 0, sizeof(*good_junk) * frag_set.size());
923   for (int dead_ch = 1; dead_ch < unicharset_.size(); ++dead_ch) {
924     int frag_ch = fragments_[dead_ch];
925     if (frag_ch <= 0) continue;
926     const char* frag_utf8 = frag_set.id_to_unichar(frag_ch);
927     CHAR_FRAGMENT* frag = CHAR_FRAGMENT::parse_from_string(frag_utf8);
928     // Mark the chars for all parts of the fragment as good in good_junk.
929     for (int part = 0; part < frag->get_total(); ++part) {
930       frag->set_pos(part);
931       int good_ch = frag_set.unichar_to_id(frag->to_string().c_str());
932       if (good_ch != INVALID_UNICHAR_ID)
933         good_junk[good_ch] = true;  // We want this one.
934     }
935     delete frag;
936   }
937 #endif
938   // For now just use all the junk that was from natural fragments.
939   // Get samples of fragments in junk_samples_ that replace the dead chars.
940   int num_junks = junk_samples_.num_samples();
941   for (int s = 0; s < num_junks; ++s) {
942     TrainingSample *sample = junk_samples_.mutable_sample(s);
943     int junk_id = sample->class_id();
944     const char *frag_utf8 = frag_set.id_to_unichar(junk_id);
945     CHAR_FRAGMENT *frag = CHAR_FRAGMENT::parse_from_string(frag_utf8);
946     if (frag != nullptr && frag->is_natural()) {
947       junk_samples_.extract_sample(s);
948       samples_.AddSample(frag_set.id_to_unichar(junk_id), sample);
949     }
950     delete frag;
951   }
952   junk_samples_.DeleteDeadSamples();
953   junk_samples_.OrganizeByFontAndClass();
954   samples_.OrganizeByFontAndClass();
955   unicharset_.clear();
956   unicharset_.AppendOtherUnicharset(samples_.unicharset());
957   // delete [] good_junk;
958   // Fragments_ no longer needed?
959   delete[] fragments_;
960   fragments_ = nullptr;
961 }
962 
963 // Runs a hierarchical agglomerative clustering to merge shapes in the given
964 // shape_table, while satisfying the given constraints:
965 // * End with at least min_shapes left in shape_table,
966 // * No shape shall have more than max_shape_unichars in it,
967 // * Don't merge shapes where the distance between them exceeds max_dist.
968 const float kInfiniteDist = 999.0f;
ClusterShapes(int min_shapes,int max_shape_unichars,float max_dist,ShapeTable * shapes)969 void MasterTrainer::ClusterShapes(int min_shapes, int max_shape_unichars,
970                                   float max_dist, ShapeTable *shapes) {
971   int num_shapes = shapes->NumShapes();
972   int max_merges = num_shapes - min_shapes;
973   // TODO: avoid new / delete.
974   auto *shape_dists = new std::vector<ShapeDist>[num_shapes];
975   float min_dist = kInfiniteDist;
976   int min_s1 = 0;
977   int min_s2 = 0;
978   tprintf("Computing shape distances...");
979   for (int s1 = 0; s1 < num_shapes; ++s1) {
980     for (int s2 = s1 + 1; s2 < num_shapes; ++s2) {
981       ShapeDist dist(s1, s2, ShapeDistance(*shapes, s1, s2));
982       shape_dists[s1].push_back(dist);
983       if (dist.distance < min_dist) {
984         min_dist = dist.distance;
985         min_s1 = s1;
986         min_s2 = s2;
987       }
988     }
989     tprintf(" %d", s1);
990   }
991   tprintf("\n");
992   int num_merged = 0;
993   while (num_merged < max_merges && min_dist < max_dist) {
994     tprintf("Distance = %f: ", min_dist);
995     int num_unichars = shapes->MergedUnicharCount(min_s1, min_s2);
996     shape_dists[min_s1][min_s2 - min_s1 - 1].distance = kInfiniteDist;
997     if (num_unichars > max_shape_unichars) {
998       tprintf("Merge of %d and %d with %d would exceed max of %d unichars\n",
999               min_s1, min_s2, num_unichars, max_shape_unichars);
1000     } else {
1001       shapes->MergeShapes(min_s1, min_s2);
1002       shape_dists[min_s2].clear();
1003       ++num_merged;
1004 
1005       for (int s = 0; s < min_s1; ++s) {
1006         if (!shape_dists[s].empty()) {
1007           shape_dists[s][min_s1 - s - 1].distance =
1008               ShapeDistance(*shapes, s, min_s1);
1009           shape_dists[s][min_s2 - s - 1].distance = kInfiniteDist;
1010         }
1011       }
1012       for (int s2 = min_s1 + 1; s2 < num_shapes; ++s2) {
1013         if (shape_dists[min_s1][s2 - min_s1 - 1].distance < kInfiniteDist) {
1014           shape_dists[min_s1][s2 - min_s1 - 1].distance =
1015               ShapeDistance(*shapes, min_s1, s2);
1016         }
1017       }
1018       for (int s = min_s1 + 1; s < min_s2; ++s) {
1019         if (!shape_dists[s].empty()) {
1020           shape_dists[s][min_s2 - s - 1].distance = kInfiniteDist;
1021         }
1022       }
1023     }
1024     min_dist = kInfiniteDist;
1025     for (int s1 = 0; s1 < num_shapes; ++s1) {
1026       for (unsigned i = 0; i < shape_dists[s1].size(); ++i) {
1027         if (shape_dists[s1][i].distance < min_dist) {
1028           min_dist = shape_dists[s1][i].distance;
1029           min_s1 = s1;
1030           min_s2 = s1 + 1 + i;
1031         }
1032       }
1033     }
1034   }
1035   tprintf("Stopped with %d merged, min dist %f\n", num_merged, min_dist);
1036   delete[] shape_dists;
1037   if (debug_level_ > 1) {
1038     for (int s1 = 0; s1 < num_shapes; ++s1) {
1039       if (shapes->MasterDestinationIndex(s1) == s1) {
1040         tprintf("Master shape:%s\n", shapes->DebugStr(s1).c_str());
1041       }
1042     }
1043   }
1044 }
1045 
1046 } // namespace tesseract.
1047