1 // This is mul/clsfy/clsfy_mean_square_1d_builder.cxx
2 #include <cmath>
3 #include <iostream>
4 #include <string>
5 #include <cstdlib>
6 #include <algorithm>
7 #include "clsfy_mean_square_1d_builder.h"
8 //:
9 // \file
10 // \author dac
11 // \date Tue Mar 5 01:11:31 2002
12
13 #include <cassert>
14 #include "vsl/vsl_binary_loader.h"
15 #include "vnl/vnl_double_2.h"
16 #include <clsfy/clsfy_builder_1d.h>
17 #include <clsfy/clsfy_mean_square_1d.h>
18 #ifdef _MSC_VER
19 # include "vcl_msvc_warnings.h"
20 #endif
21
22 //=======================================================================
23
//: Default constructor
// Defined out of line but compiler-generated: no extra initialisation is done here.
clsfy_mean_square_1d_builder::clsfy_mean_square_1d_builder() = default;
25
26 //=======================================================================
27
//: Destructor
// Defined out of line but compiler-generated: no explicit clean-up is done here.
clsfy_mean_square_1d_builder::~clsfy_mean_square_1d_builder() = default;
29
30 //=======================================================================
31
version_no() const32 short clsfy_mean_square_1d_builder::version_no() const
33 {
34 return 1;
35 }
36
37
38 //: Create empty classifier
39 // Caller is responsible for deletion
new_classifier() const40 clsfy_classifier_1d* clsfy_mean_square_1d_builder::new_classifier() const
41 {
42 return new clsfy_mean_square_1d();
43 }
44
45
46 //: Build a binary_threshold classifier
47 // Train classifier, returning weighted error
48 // Selects parameters of classifier which best separate examples from two classes,
49 // weighting examples appropriately when estimating the misclassification rate.
50 // Returns weighted sum of error, e.wts, where e_i =0 for correct classifications,
51 // e_i=1 for incorrect.
build(clsfy_classifier_1d & classifier,const vnl_vector<double> & egs,const vnl_vector<double> & wts,const std::vector<unsigned> & outputs) const52 double clsfy_mean_square_1d_builder::build(clsfy_classifier_1d& classifier,
53 const vnl_vector<double>& egs,
54 const vnl_vector<double>& wts,
55 const std::vector<unsigned> &outputs) const
56 {
57 // this method sorts the data and passes it to the method below
58 assert(classifier.is_class("clsfy_mean_square_1d"));
59
60 unsigned int n = egs.size();
61 assert ( wts.size() == n );
62 assert ( outputs.size() == n );
63
64 // calc weighted mean of positive data
65 double wm_pos= 0.0;
66 double tot_pos_wts=0.0, tot_neg_wts=0.0;
67 unsigned int n_pos=0, n_neg=0;
68 for (unsigned int i=0; i<n; ++i)
69 {
70 #ifdef DEBUG
71 std::cout<<"egs["<<i<<"]= "<<egs[i]<<std::endl
72 <<"wts["<<i<<"]= "<<wts[i]<<std::endl
73 <<"outputs["<<i<<"]= "<<outputs[i]<<std::endl;
74 #endif
75 if ( outputs[i] == 1 )
76 {
77 //std::cout<<"wm_pos= "<<wm_pos<<std::endl;
78 wm_pos+= wts(i)*egs(i);
79 tot_pos_wts+= wts(i);
80 ++n_pos;
81 }
82 else
83 {
84 tot_neg_wts+= wts(i);
85 ++n_neg;
86 }
87 }
88
89 assert( n_pos+n_neg== n );
90 wm_pos/=tot_pos_wts;
91 #ifdef DEBUG
92 std::cout<<"wm_pos= "<<wm_pos<<std::endl;
93 #endif
94 // create triples data, so can sort
95 std::vector<vbl_triple<double,int,int> > data;
96
97 vbl_triple<double,int,int> t;
98 // add data to triples
99 for (unsigned int i=0;i<n;++i)
100 {
101 double k= wm_pos-egs[i];
102 t.first=k*k;
103 t.second= outputs[i];
104 t.third = i;
105 data.push_back(t);
106 }
107
108 vbl_triple<double,int,int> *data_ptr=&data[0];
109 std::sort(data_ptr,data_ptr+n);
110
111 double wt_pos=0;
112 double wt_neg=0;
113 double min_error= 1000000;
114 double min_thresh= -1;
115 for (unsigned int i=0;i<n;++i)
116 {
117 if ( data[i].second == 0 ) wt_neg+= wts[ data[i].third] ;
118 else if ( data[i].second == 1 ) wt_pos+= wts[ data[i].third];
119 else
120 {
121 std::cout<<"ERROR: clsfy_mean_square_1d_builder::build()\n"
122 <<"Unrecognised output value in triple (ie must be 0 or 1)\n"
123 <<"data.second="<<data[i].second<<std::endl;
124 std::abort();
125 }
126 double error= tot_pos_wts-wt_pos+wt_neg;
127 #ifdef DEBUG
128 std::cout<<"data[i].first= "<<data[i].first<<std::endl
129 <<"data[i].second= "<<data[i].second<<std::endl
130 <<"data[i].third= "<<data[i].third<<std::endl
131
132 <<"wt_pos= "<<wt_pos<<std::endl
133 <<"tot_wts1= "<<tot_wts1<<std::endl
134 <<"wt_neg= "<<wt_neg<<std::endl
135
136 <<"error= "<<error<<std::endl;
137 #endif
138 if ( error< min_error )
139 {
140 min_error= error;
141 min_thresh = data[i].first + 0.001 ;
142 }
143 }
144
145 assert( std::fabs (wt_pos - tot_pos_wts) < 1e-9 );
146 assert( std::fabs (wt_neg - tot_neg_wts) < 1e-9 );
147 #ifdef DEBUG
148 std::cout<<"min_error= "<<min_error<<std::endl
149 <<"min_thresh= "<<min_thresh<<std::endl;
150 #endif
151 // pass parameters to classifier
152 classifier.set_params(vnl_double_2(wm_pos,min_thresh).as_vector());
153 return min_error;
154 }
155
156
157 //: Build a mean_square classifier
158 // nb here egs0 are -ve examples
159 // and egs1 are +ve examples
build(clsfy_classifier_1d & classifier,vnl_vector<double> & egs0,vnl_vector<double> & wts0,vnl_vector<double> & egs1,vnl_vector<double> & wts1) const160 double clsfy_mean_square_1d_builder::build(clsfy_classifier_1d& classifier,
161 vnl_vector<double>& egs0,
162 vnl_vector<double>& wts0,
163 vnl_vector<double>& egs1,
164 vnl_vector<double>& wts1) const
165 {
166 // this method sorts the data and passes it to the method below
167 assert(classifier.is_class("clsfy_mean_square_1d"));
168
169 // find mean of positive data (ie egs1) then calc square distance from mean
170 // for each example
171 unsigned int n0 = egs0.size();
172 unsigned int n1 = egs1.size();
173 assert (wts0.size() == n0 );
174 assert (wts1.size() == n1 );
175
176 // calc weighted mean of positive data
177 double tot_wts1= wts1.mean()*n1;
178 double wm_pos=0.0;
179 for (unsigned int i=0; i< n1; ++i)
180 {
181 wm_pos+= wts1(i)*egs1(i);
182 #ifdef DEBUG
183 std::cout<<"egs1("<<i<<")= "<<egs1(i)<<std::endl
184 <<"wts1("<<i<<")= "<<wts1(i)<<std::endl;
185 #endif
186 }
187 wm_pos/=tot_wts1;
188
189 std::cout<<"wm_pos= "<<wm_pos<<std::endl;
190
191 std::vector<vbl_triple<double,int,int> > data;
192
193 vnl_vector<double> wts(n0+n1);
194 vbl_triple<double,int,int> t;
195 // add data for class 0
196 for (unsigned int i=0;i<n0;++i)
197 {
198 double k= wm_pos-egs0[i];
199 t.first=k*k;
200 t.second=0;
201 t.third = i;
202 wts(i)= wts0[i];
203 data.push_back(t);
204 }
205
206 // add data for class 1
207 for (unsigned int i=0;i<n1;++i)
208 {
209 double k= wm_pos-egs1[i];
210 t.first=k*k;
211 t.second=1;
212 t.third = i+n0;
213 wts(i+n0)= wts1[i];
214 data.push_back(t);
215 }
216
217 unsigned int n=n0+n1;
218
219 vbl_triple<double,int,int> *data_ptr=&data[0];
220 std::sort(data_ptr,data_ptr+n);
221
222 double wt_pos=0;
223 double wt_neg=0;
224 double min_error= 1000000;
225 double min_thresh= -1;
226 for (unsigned int i=0;i<n;++i)
227 {
228 if ( data[i].second == 0 ) wt_neg+= wts[ data[i].third] ;
229 else if ( data[i].second == 1 ) wt_pos+= wts[ data[i].third];
230 else
231 {
232 std::cout<<"ERROR: clsfy_mean_square_1d_builder::build()\n"
233 <<"Unrecognised output value in triple (ie must be 0 or 1)\n"
234 <<"data.second="<<data[i].second<<std::endl;
235 std::abort();
236 }
237 double error= tot_wts1-wt_pos+wt_neg;
238 #ifdef DEBUG
239 std::cout<<"data[i].first= "<<data[i].first<<std::endl
240 <<"data[i].second= "<<data[i].second<<std::endl
241 <<"data[i].third= "<<data[i].third<<std::endl
242
243 <<"wt_pos= "<<wt_pos<<std::endl
244 <<"tot_wts1= "<<tot_wts1<<std::endl
245 <<"wt_neg= "<<wt_neg<<std::endl
246
247 <<"error= "<<error<<std::endl;
248 #endif
249 if ( error< min_error )
250 {
251 min_error= error;
252 min_thresh = data[i].first + 0.001 ;
253 }
254 }
255
256 assert( std::fabs (wt_pos - tot_wts1) < 1e-9 );
257 assert( std::fabs (wt_neg - wts0.mean()*n0) < 1e-9 );
258 std::cout<<"min_error= "<<min_error<<std::endl
259 <<"min_thresh= "<<min_thresh<<std::endl;
260
261 // pass parameters to classifier
262 classifier.set_params(vnl_double_2(wm_pos,min_thresh).as_vector());
263 return min_error;
264 }
265
266
267 //: Train classifier, returning weighted error
268 // Assumes two classes
build_from_sorted_data(clsfy_classifier_1d &,const vbl_triple<double,int,int> *,const vnl_vector<double> &) const269 double clsfy_mean_square_1d_builder::build_from_sorted_data(
270 clsfy_classifier_1d& /*classifier*/,
271 const vbl_triple<double,int,int>* /*data*/,
272 const vnl_vector<double>& /*wts*/
273 ) const
274 {
275 std::cout<<"ERROR: clsfy_mean_square_1d_builder::build_from_sorted_data()\n"
276 <<"Function not implemented because can't use pre-sorted data\n"
277 <<"the weighted mean of the data is needed to calc the ordering!\n";
278 std::abort();
279
280 return 0.0;
281 }
282
283 //=======================================================================
284
is_a() const285 std::string clsfy_mean_square_1d_builder::is_a() const
286 {
287 return std::string("clsfy_mean_square_1d_builder");
288 }
289
is_class(std::string const & s) const290 bool clsfy_mean_square_1d_builder::is_class(std::string const& s) const
291 {
292 return s == clsfy_mean_square_1d_builder::is_a() || clsfy_builder_1d::is_class(s);
293 }
294
295 //=======================================================================
296
// The copy constructor and assignment operator below are disabled with #if 0,
// so they are never compiled.
// NOTE(review): they refer to members data_ptr_ and data_ whose declarations
// are not visible in this file - presumably boilerplate for a class that owns
// heap data; confirm the members exist before enabling this code.
#if 0 // two functions commented out

// required if data stored on the heap is present in this derived class
// Copy constructor: starts with a null data_ptr_ and then reuses operator=.
clsfy_mean_square_1d_builder::clsfy_mean_square_1d_builder(
                             const clsfy_mean_square_1d_builder& new_b) :
  data_ptr_(0)
{
  *this = new_b;
}

//=======================================================================

// required if data stored on the heap is present in this derived class
// Assignment: replaces the heap object with a clone of new_b's (if it has one)
// and copies data_.
clsfy_mean_square_1d_builder&
clsfy_mean_square_1d_builder::operator=(const clsfy_mean_square_1d_builder& new_b)
{
  // Self-assignment: nothing to do.
  if (&new_b==this) return *this;

  // Copy heap member variables.
  // The old object is released first; data_ptr_ stays null if new_b has none.
  delete data_ptr_; data_ptr_=0;

  if (new_b.data_ptr_)
    data_ptr_ = new_b.data_ptr_->clone();

  // Copy normal member variables
  data_ = new_b.data_;

  return *this;
}

#endif // 0
328
329 //=======================================================================
330
clone() const331 clsfy_builder_1d* clsfy_mean_square_1d_builder::clone() const
332 {
333 return new clsfy_mean_square_1d_builder(*this);
334 }
335
336 //=======================================================================
337
338 // required if data is present in this base class
print_summary(std::ostream &) const339 void clsfy_mean_square_1d_builder::print_summary(std::ostream& /*os*/) const
340 {
341 // clsfy_builder_1d::print_summary(os); // Uncomment this line if it has one.
342 // vsl_print_summary(os, data_); // Example of data output
343
344 std::cerr << "clsfy_mean_square_1d_builder::print_summary() NYI\n";
345 }
346
347 //=======================================================================
348
349 // required if data is present in this base class
b_write(vsl_b_ostream &) const350 void clsfy_mean_square_1d_builder::b_write(vsl_b_ostream& /*bfs*/) const
351 {
352 //vsl_b_write(bfs, version_no());
353 //clsfy_builder_1d::b_write(bfs); // Needed if base has any data
354 //vsl_b_write(bfs, data_);
355 std::cerr << "clsfy_mean_square_1d_builder::b_write() NYI\n";
356 }
357
358 //=======================================================================
359
360 // required if data is present in this base class
b_read(vsl_b_istream &)361 void clsfy_mean_square_1d_builder::b_read(vsl_b_istream& /*bfs*/)
362 {
363 std::cerr << "clsfy_mean_square_1d_builder::b_read() NYI\n";
364 #if 0
365 if (!bfs) return;
366
367 short version;
368 vsl_b_read(bfs,version);
369 switch (version)
370 {
371 case (1):
372 //clsfy_builder_1d::b_read(bfs); // Needed if base has any data
373 vsl_b_read(bfs,data_);
374 break;
375 default:
376 std::cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_mean_square_1d_builder&)\n"
377 << " Unknown version number "<< version << '\n';
378 bfs.is().clear(std::ios::badbit); // Set an unrecoverable IO error on stream
379 return;
380 }
381 #endif
382 }
383