1 // Copyright 2011 Google Inc. All Rights Reserved.
2 // Author: rays@google.com (Ray Smith)
3 ///////////////////////////////////////////////////////////////////////
4 // File:        shapeclassifier.cpp
5 // Description: Base interface class for classifiers that return a
6 //              shape index.
7 // Author:      Ray Smith
8 //
9 // (C) Copyright 2011, Google Inc.
10 // Licensed under the Apache License, Version 2.0 (the "License");
11 // you may not use this file except in compliance with the License.
12 // You may obtain a copy of the License at
13 // http://www.apache.org/licenses/LICENSE-2.0
14 // Unless required by applicable law or agreed to in writing, software
15 // distributed under the License is distributed on an "AS IS" BASIS,
16 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 // See the License for the specific language governing permissions and
18 // limitations under the License.
19 //
20 ///////////////////////////////////////////////////////////////////////
21 
22 #ifdef HAVE_CONFIG_H
23 #  include "config_auto.h"
24 #endif
25 
26 #include "shapeclassifier.h"
27 
28 #include "scrollview.h"
29 #include "shapetable.h"
30 #ifndef GRAPHICS_DISABLED
31 #include "svmnode.h"
32 #endif
33 #include "tprintf.h"
34 #include "trainingsample.h"
35 
36 namespace tesseract {
37 
38 // Classifies the given [training] sample, writing to results.
39 // See shapeclassifier.h for a full description.
40 // Default implementation calls the ShapeRating version.
UnicharClassifySample(const TrainingSample & sample,Image page_pix,int debug,UNICHAR_ID keep_this,std::vector<UnicharRating> * results)41 int ShapeClassifier::UnicharClassifySample(const TrainingSample &sample, Image page_pix, int debug,
42                                            UNICHAR_ID keep_this,
43                                            std::vector<UnicharRating> *results) {
44   results->clear();
45   std::vector<ShapeRating> shape_results;
46   int num_shape_results = ClassifySample(sample, page_pix, debug, keep_this, &shape_results);
47   const ShapeTable *shapes = GetShapeTable();
48   std::vector<int> unichar_map(shapes->unicharset().size(), -1);
49   for (int r = 0; r < num_shape_results; ++r) {
50     shapes->AddShapeToResults(shape_results[r], &unichar_map, results);
51   }
52   return results->size();
53 }
54 
55 // Classifies the given [training] sample, writing to results.
56 // See shapeclassifier.h for a full description.
57 // Default implementation aborts.
ClassifySample(const TrainingSample & sample,Image page_pix,int debug,int keep_this,std::vector<ShapeRating> * results)58 int ShapeClassifier::ClassifySample(const TrainingSample &sample, Image page_pix, int debug,
59                                     int keep_this, std::vector<ShapeRating> *results) {
60   ASSERT_HOST("Must implement ClassifySample!" == nullptr);
61   return 0;
62 }
63 
// Returns the id of the first shape in the classifier's results that contains
// unichar_id (normally the best-rated such shape), or -1 if none does.
// If result is not nullptr, it is set with the matching shape_id and rating.
// Does not need to be overridden if ClassifySample respects the keep_this
// rule.
BestShapeForUnichar(const TrainingSample & sample,Image page_pix,UNICHAR_ID unichar_id,ShapeRating * result)68 int ShapeClassifier::BestShapeForUnichar(const TrainingSample &sample, Image page_pix,
69                                          UNICHAR_ID unichar_id, ShapeRating *result) {
70   std::vector<ShapeRating> results;
71   const ShapeTable *shapes = GetShapeTable();
72   int num_results = ClassifySample(sample, page_pix, 0, unichar_id, &results);
73   for (int r = 0; r < num_results; ++r) {
74     if (shapes->GetShape(results[r].shape_id).ContainsUnichar(unichar_id)) {
75       if (result != nullptr) {
76         *result = results[r];
77       }
78       return results[r].shape_id;
79     }
80   }
81   return -1;
82 }
83 
84 // Provides access to the UNICHARSET that this classifier works with.
85 // Only needs to be overridden if GetShapeTable() can return nullptr.
GetUnicharset() const86 const UNICHARSET &ShapeClassifier::GetUnicharset() const {
87   return GetShapeTable()->unicharset();
88 }
89 
90 #ifndef GRAPHICS_DISABLED
91 
92 // Visual debugger classifies the given sample, displays the results and
93 // solicits user input to display other classifications. Returns when
94 // the user has finished with debugging the sample.
95 // Probably doesn't need to be overridden if the subclass provides
96 // DisplayClassifyAs.
DebugDisplay(const TrainingSample & sample,Image page_pix,UNICHAR_ID unichar_id)97 void ShapeClassifier::DebugDisplay(const TrainingSample &sample, Image page_pix,
98                                    UNICHAR_ID unichar_id) {
99   static ScrollView *terminator = nullptr;
100   if (terminator == nullptr) {
101     terminator = new ScrollView("XIT", 0, 0, 50, 50, 50, 50, true);
102   }
103   ScrollView *debug_win = CreateFeatureSpaceWindow("ClassifierDebug", 0, 0);
104   // Provide a right-click menu to choose the class.
105   auto *popup_menu = new SVMenuNode();
106   popup_menu->AddChild("Choose class to debug", 0, "x", "Class to debug");
107   popup_menu->BuildMenu(debug_win, false);
108   // Display the features in green.
109   const INT_FEATURE_STRUCT *features = sample.features();
110   uint32_t num_features = sample.num_features();
111   for (uint32_t f = 0; f < num_features; ++f) {
112     RenderIntFeature(debug_win, &features[f], ScrollView::GREEN);
113   }
114   debug_win->Update();
115   std::vector<UnicharRating> results;
116   // Debug classification until the user quits.
117   const UNICHARSET &unicharset = GetUnicharset();
118   SVEvent *ev;
119   SVEventType ev_type;
120   do {
121     std::vector<ScrollView *> windows;
122     if (unichar_id >= 0) {
123       tprintf("Debugging class %d = %s\n", unichar_id, unicharset.id_to_unichar(unichar_id));
124       UnicharClassifySample(sample, page_pix, 1, unichar_id, &results);
125       DisplayClassifyAs(sample, page_pix, unichar_id, 1, windows);
126     } else {
127       tprintf("Invalid unichar_id: %d\n", unichar_id);
128       UnicharClassifySample(sample, page_pix, 1, -1, &results);
129     }
130     if (unichar_id >= 0) {
131       tprintf("Debugged class %d = %s\n", unichar_id, unicharset.id_to_unichar(unichar_id));
132     }
133     tprintf("Right-click in ClassifierDebug window to choose debug class,");
134     tprintf(" Left-click or close window to quit...\n");
135     UNICHAR_ID old_unichar_id;
136     do {
137       old_unichar_id = unichar_id;
138       ev = debug_win->AwaitEvent(SVET_ANY);
139       ev_type = ev->type;
140       if (ev_type == SVET_POPUP) {
141         if (unicharset.contains_unichar(ev->parameter)) {
142           unichar_id = unicharset.unichar_to_id(ev->parameter);
143         } else {
144           tprintf("Char class '%s' not found in unicharset", ev->parameter);
145         }
146       }
147       delete ev;
148     } while (unichar_id == old_unichar_id && ev_type != SVET_CLICK && ev_type != SVET_DESTROY);
149     for (auto window : windows) {
150       delete window;
151     }
152   } while (ev_type != SVET_CLICK && ev_type != SVET_DESTROY);
153   delete debug_win;
154 }
155 
156 #endif // !GRAPHICS_DISABLED
157 
// Displays classification as the given unichar_id. Creates as many windows
// as it sees fit, using index as a guide for placement. Adds any created
// windows to the windows output and returns a new index that may be used
// by any subsequent classifiers. The caller waits for the user to view the
// windows and then deletes each window in the vector.
DisplayClassifyAs(const TrainingSample & sample,Image page_pix,UNICHAR_ID unichar_id,int index,std::vector<ScrollView * > & windows)163 int ShapeClassifier::DisplayClassifyAs(const TrainingSample &sample, Image page_pix,
164                                        UNICHAR_ID unichar_id, int index,
165                                        std::vector<ScrollView *> &windows) {
166   // Does nothing in the default implementation.
167   return index;
168 }
169 
170 // Prints debug information on the results.
UnicharPrintResults(const char * context,const std::vector<UnicharRating> & results) const171 void ShapeClassifier::UnicharPrintResults(const char *context,
172                                           const std::vector<UnicharRating> &results) const {
173   tprintf("%s\n", context);
174   for (const auto &result : results) {
175     tprintf("%g: c_id=%d=%s", result.rating, result.unichar_id,
176             GetUnicharset().id_to_unichar(result.unichar_id));
177     if (!result.fonts.empty()) {
178       tprintf(" Font Vector:");
179       for (auto font : result.fonts) {
180         tprintf(" %d", font.fontinfo_id);
181       }
182     }
183     tprintf("\n");
184   }
185 }
PrintResults(const char * context,const std::vector<ShapeRating> & results) const186 void ShapeClassifier::PrintResults(const char *context,
187                                    const std::vector<ShapeRating> &results) const {
188   tprintf("%s\n", context);
189   for (const auto &result : results) {
190     tprintf("%g:", result.rating);
191     if (result.joined) {
192       tprintf("[J]");
193     }
194     if (result.broken) {
195       tprintf("[B]");
196     }
197     tprintf(" %s\n", GetShapeTable()->DebugStr(result.shape_id).c_str());
198   }
199 }
200 
201 // Removes any result that has all its unichars covered by a better choice,
202 // regardless of font.
FilterDuplicateUnichars(std::vector<ShapeRating> * results) const203 void ShapeClassifier::FilterDuplicateUnichars(std::vector<ShapeRating> *results) const {
204   std::vector<ShapeRating> filtered_results;
205   // Copy results to filtered results and knock out duplicate unichars.
206   const ShapeTable *shapes = GetShapeTable();
207   for (unsigned r = 0; r < results->size(); ++r) {
208     if (r > 0) {
209       const Shape &shape_r = shapes->GetShape((*results)[r].shape_id);
210       int c;
211       for (c = 0; c < shape_r.size(); ++c) {
212         int unichar_id = shape_r[c].unichar_id;
213         unsigned s;
214         for (s = 0; s < r; ++s) {
215           const Shape &shape_s = shapes->GetShape((*results)[s].shape_id);
216           if (shape_s.ContainsUnichar(unichar_id)) {
217             break; // We found unichar_id.
218           }
219         }
220         if (s == r) {
221           break; // We didn't find unichar_id.
222         }
223       }
224       if (c == shape_r.size()) {
225         continue; // We found all the unichar ids in previous answers.
226       }
227     }
228     filtered_results.push_back((*results)[r]);
229   }
230   *results = filtered_results;
231 }
232 
233 } // namespace tesseract.
234