1 // Copyright 2011 Google Inc. All Rights Reserved.
2 // Author: rays@google.com (Ray Smith)
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
14 ///////////////////////////////////////////////////////////////////////
15
16 #include "sampleiterator.h"
17
18 #include "intfeaturemap.h"
19
20 #include "indexmapbidi.h"
21 #include "shapetable.h"
22 #include "trainingsample.h"
23 #include "trainingsampleset.h"
24
25 namespace tesseract {
26
27 // ================== SampleIterator Implementation =================
28
SampleIterator()29 SampleIterator::SampleIterator()
30 : charset_map_(nullptr)
31 , shape_table_(nullptr)
32 , sample_set_(nullptr)
33 , randomize_(false)
34 , owned_shape_table_(nullptr) {
35 num_shapes_ = 0;
36 Begin();
37 }
38
~SampleIterator()39 SampleIterator::~SampleIterator() {
40 Clear();
41 }
42
Clear()43 void SampleIterator::Clear() {
44 delete owned_shape_table_;
45 owned_shape_table_ = nullptr;
46 }
47
48 // See class comment for arguments.
Init(const IndexMapBiDi * charset_map,const ShapeTable * shape_table,bool randomize,TrainingSampleSet * sample_set)49 void SampleIterator::Init(const IndexMapBiDi *charset_map, const ShapeTable *shape_table,
50 bool randomize, TrainingSampleSet *sample_set) {
51 Clear();
52 charset_map_ = charset_map;
53 shape_table_ = shape_table;
54 sample_set_ = sample_set;
55 randomize_ = randomize;
56 if (shape_table_ == nullptr && charset_map_ != nullptr) {
57 // The caller wishes to iterate by class. The easiest way to do this
58 // is to create a dummy shape_table_ that we will own.
59 int num_fonts = sample_set_->NumFonts();
60 owned_shape_table_ = new ShapeTable(sample_set_->unicharset());
61 int charsetsize = sample_set_->unicharset().size();
62 for (int c = 0; c < charsetsize; ++c) {
63 // We always add a shape for each character to keep the index in sync
64 // with the unichar_id.
65 int shape_id = owned_shape_table_->AddShape(c, 0);
66 for (int f = 1; f < num_fonts; ++f) {
67 if (sample_set_->NumClassSamples(f, c, true) > 0) {
68 owned_shape_table_->AddToShape(shape_id, c, f);
69 }
70 }
71 }
72 shape_table_ = owned_shape_table_;
73 }
74 if (shape_table_ != nullptr) {
75 num_shapes_ = shape_table_->NumShapes();
76 } else {
77 num_shapes_ = randomize ? sample_set_->num_samples() : sample_set_->num_raw_samples();
78 }
79 Begin();
80 }
81
82 // Iterator functions designed for use with a simple for loop:
83 // for (it.Begin(); !it.AtEnd(); it.Next()) {
84 // const TrainingSample& sample = it.GetSample();
85 // }
Begin()86 void SampleIterator::Begin() {
87 shape_index_ = -1;
88 shape_char_index_ = 0;
89 num_shape_chars_ = 0;
90 shape_font_index_ = 0;
91 num_shape_fonts_ = 0;
92 sample_index_ = 0;
93 num_samples_ = 0;
94 // Find the first indexable sample.
95 Next();
96 }
97
AtEnd() const98 bool SampleIterator::AtEnd() const {
99 return shape_index_ >= num_shapes_;
100 }
101
GetSample() const102 const TrainingSample &SampleIterator::GetSample() const {
103 if (shape_table_ != nullptr) {
104 const UnicharAndFonts *shape_entry = GetShapeEntry();
105 int char_id = shape_entry->unichar_id;
106 int font_id = shape_entry->font_ids[shape_font_index_];
107 return *sample_set_->GetSample(font_id, char_id, sample_index_);
108 } else {
109 return *sample_set_->GetSample(shape_index_);
110 }
111 }
112
MutableSample() const113 TrainingSample *SampleIterator::MutableSample() const {
114 if (shape_table_ != nullptr) {
115 const UnicharAndFonts *shape_entry = GetShapeEntry();
116 int char_id = shape_entry->unichar_id;
117 int font_id = shape_entry->font_ids[shape_font_index_];
118 return sample_set_->MutableSample(font_id, char_id, sample_index_);
119 } else {
120 return sample_set_->mutable_sample(shape_index_);
121 }
122 }
123
124 // Returns the total index (from the original set of samples) of the current
125 // sample.
GlobalSampleIndex() const126 int SampleIterator::GlobalSampleIndex() const {
127 if (shape_table_ != nullptr) {
128 const UnicharAndFonts *shape_entry = GetShapeEntry();
129 int char_id = shape_entry->unichar_id;
130 int font_id = shape_entry->font_ids[shape_font_index_];
131 return sample_set_->GlobalSampleIndex(font_id, char_id, sample_index_);
132 } else {
133 return shape_index_;
134 }
135 }
136
137 // Returns the index of the current sample in compact charset space, so
138 // in a 2-class problem between x and y, the returned indices will all be
139 // 0 or 1, and have nothing to do with the unichar_ids.
140 // If the charset_map_ is nullptr, then this is equal to GetSparseClassID().
GetCompactClassID() const141 int SampleIterator::GetCompactClassID() const {
142 return charset_map_ != nullptr ? charset_map_->SparseToCompact(shape_index_) : GetSparseClassID();
143 }
144 // Returns the index of the current sample in sparse charset space, so
145 // in a 2-class problem between x and y, the returned indices will all be
146 // x or y, where x and y may be unichar_ids (no shape_table_) or shape_ids
147 // with a shape_table_.
GetSparseClassID() const148 int SampleIterator::GetSparseClassID() const {
149 return shape_table_ != nullptr ? shape_index_ : GetSample().class_id();
150 }
151
// Moves on to the next indexable sample. If the end is reached, leaves
// the state such that AtEnd() is true.
// With a shape_table_, the position is four nested cursors, innermost first:
//   sample_index_      over the samples of one (font, unichar) pair,
//   shape_font_index_  over the fonts of one unichar entry of the shape,
//   shape_char_index_  over the unichar entries of the shape,
//   shape_index_       over shapes, skipping any that charset_map_ maps to < 0.
// An outer cursor advances only when the one inside it is exhausted, and
// (font, unichar) pairs with zero samples are skipped. Without a shape_table_,
// shape_index_ simply steps through the sample indices.
// NOTE(review): GetShapeEntry() and font_ids[shape_font_index_] are indexed
// without a size check, so this assumes every shape has at least one unichar
// entry and every entry at least one font -- confirm against ShapeTable users.
void SampleIterator::Next() {
  if (shape_table_ != nullptr) {
    // Next sample in this class/font combination.
    ++sample_index_;
    if (sample_index_ < num_samples_) {
      return;
    }
    // Next font in this class in this shape.
    sample_index_ = 0;
    do {
      ++shape_font_index_;
      if (shape_font_index_ >= num_shape_fonts_) {
        // Next unichar in this shape.
        shape_font_index_ = 0;
        ++shape_char_index_;
        if (shape_char_index_ >= num_shape_chars_) {
          // Find the next shape that is mapped in the charset_map_.
          shape_char_index_ = 0;
          do {
            ++shape_index_;
          } while (shape_index_ < num_shapes_ && charset_map_ != nullptr &&
                   charset_map_->SparseToCompact(shape_index_) < 0);
          if (shape_index_ >= num_shapes_) {
            return; // The end.
          }
          num_shape_chars_ = shape_table_->GetShape(shape_index_).size();
        }
      }
      // Refresh the cached font count and sample count for the new
      // (font, unichar) position; a count of 0 loops round to the next font.
      const UnicharAndFonts *shape_entry = GetShapeEntry();
      num_shape_fonts_ = shape_entry->font_ids.size();
      int char_id = shape_entry->unichar_id;
      int font_id = shape_entry->font_ids[shape_font_index_];
      num_samples_ = sample_set_->NumClassSamples(font_id, char_id, randomize_);
    } while (num_samples_ == 0);
  } else {
    // We are just iterating over the samples.
    ++shape_index_;
  }
}
193
194 // Returns the size of the compact charset space.
CompactCharsetSize() const195 int SampleIterator::CompactCharsetSize() const {
196 return charset_map_ != nullptr ? charset_map_->CompactSize() : SparseCharsetSize();
197 }
198
199 // Returns the size of the sparse charset space.
SparseCharsetSize() const200 int SampleIterator::SparseCharsetSize() const {
201 return charset_map_ != nullptr
202 ? charset_map_->SparseSize()
203 : (shape_table_ != nullptr ? shape_table_->NumShapes() : sample_set_->charsetsize());
204 }
205
206 // Sets the mapped_features_ from the features using the provided
207 // feature_map.
MapFeatures(TrainingSample & s,const IntFeatureMap & feature_map)208 static void MapFeatures(TrainingSample &s, const IntFeatureMap &feature_map) {
209 std::vector<int> indexed_features;
210 feature_map.feature_space().IndexAndSortFeatures(s.features(), s.num_features(),
211 &indexed_features);
212 feature_map.MapIndexedFeatures(indexed_features, &s.mapped_features_);
213 s.features_are_indexed_ = false;
214 s.features_are_mapped_ = true;
215 }
216
217 // Apply the supplied feature_space/feature_map transform to all samples
218 // accessed by this iterator.
MapSampleFeatures(const IntFeatureMap & feature_map)219 void SampleIterator::MapSampleFeatures(const IntFeatureMap &feature_map) {
220 for (Begin(); !AtEnd(); Next()) {
221 TrainingSample *sample = MutableSample();
222 MapFeatures(*sample, feature_map);
223 }
224 }
225
226 // Adjust the weights of all the samples to be uniform in the given charset.
227 // Returns the number of samples in the iterator.
UniformSamples()228 int SampleIterator::UniformSamples() {
229 int num_good_samples = 0;
230 for (Begin(); !AtEnd(); Next()) {
231 TrainingSample *sample = MutableSample();
232 sample->set_weight(1.0);
233 ++num_good_samples;
234 }
235 NormalizeSamples();
236 return num_good_samples;
237 }
238
239 // Normalize the weights of all the samples in the charset_map so they sum
240 // to 1. Returns the minimum assigned sample weight.
NormalizeSamples()241 double SampleIterator::NormalizeSamples() {
242 double total_weight = 0.0;
243 int sample_count = 0;
244 for (Begin(); !AtEnd(); Next()) {
245 const TrainingSample &sample = GetSample();
246 total_weight += sample.weight();
247 ++sample_count;
248 }
249 // Normalize samples.
250 double min_assigned_sample_weight = 1.0;
251 if (total_weight > 0.0) {
252 for (Begin(); !AtEnd(); Next()) {
253 TrainingSample *sample = MutableSample();
254 double weight = sample->weight() / total_weight;
255 if (weight < min_assigned_sample_weight) {
256 min_assigned_sample_weight = weight;
257 }
258 sample->set_weight(weight);
259 }
260 }
261 return min_assigned_sample_weight;
262 }
263
264 // Helper returns the current UnicharAndFont shape_entry.
GetShapeEntry() const265 const UnicharAndFonts *SampleIterator::GetShapeEntry() const {
266 const Shape &shape = shape_table_->GetShape(shape_index_);
267 return &shape[shape_char_index_];
268 }
269
270 } // namespace tesseract.
271