1 // This is mul/clsfy/clsfy_mean_square_1d_builder.cxx
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <string>
#include <vector>
#include "clsfy_mean_square_1d_builder.h"
8 //:
9 // \file
10 // \author dac
11 // \date   Tue Mar  5 01:11:31 2002
12 
13 #include <cassert>
14 #include "vsl/vsl_binary_loader.h"
15 #include "vnl/vnl_double_2.h"
16 #include <clsfy/clsfy_builder_1d.h>
17 #include <clsfy/clsfy_mean_square_1d.h>
18 #ifdef _MSC_VER
19 #  include "vcl_msvc_warnings.h"
20 #endif
21 
22 //=======================================================================
23 
24 clsfy_mean_square_1d_builder::clsfy_mean_square_1d_builder() = default;
25 
26 //=======================================================================
27 
28 clsfy_mean_square_1d_builder::~clsfy_mean_square_1d_builder() = default;
29 
30 //=======================================================================
31 
version_no() const32 short clsfy_mean_square_1d_builder::version_no() const
33 {
34   return 1;
35 }
36 
37 
38 //: Create empty classifier
39 // Caller is responsible for deletion
new_classifier() const40 clsfy_classifier_1d* clsfy_mean_square_1d_builder::new_classifier() const
41 {
42   return new clsfy_mean_square_1d();
43 }
44 
45 
46 //: Build a binary_threshold classifier
47 //  Train classifier, returning weighted error
48 //  Selects parameters of classifier which best separate examples from two classes,
49 //  weighting examples appropriately when estimating the misclassification rate.
50 //  Returns weighted sum of error, e.wts, where e_i =0 for correct classifications,
51 //  e_i=1 for incorrect.
build(clsfy_classifier_1d & classifier,const vnl_vector<double> & egs,const vnl_vector<double> & wts,const std::vector<unsigned> & outputs) const52 double clsfy_mean_square_1d_builder::build(clsfy_classifier_1d& classifier,
53                                            const vnl_vector<double>& egs,
54                                            const vnl_vector<double>& wts,
55                                            const std::vector<unsigned> &outputs) const
56 {
57   // this method sorts the data and passes it to the method below
58   assert(classifier.is_class("clsfy_mean_square_1d"));
59 
60   unsigned int n = egs.size();
61   assert ( wts.size() == n );
62   assert ( outputs.size() == n );
63 
64   // calc weighted mean of positive data
65   double wm_pos= 0.0;
66   double tot_pos_wts=0.0, tot_neg_wts=0.0;
67   unsigned int n_pos=0, n_neg=0;
68   for (unsigned int i=0; i<n; ++i)
69   {
70 #ifdef DEBUG
71     std::cout<<"egs["<<i<<"]= "<<egs[i]<<std::endl
72             <<"wts["<<i<<"]= "<<wts[i]<<std::endl
73             <<"outputs["<<i<<"]= "<<outputs[i]<<std::endl;
74 #endif
75     if ( outputs[i] == 1 )
76     {
77       //std::cout<<"wm_pos= "<<wm_pos<<std::endl;
78       wm_pos+= wts(i)*egs(i);
79       tot_pos_wts+= wts(i);
80       ++n_pos;
81     }
82     else
83     {
84       tot_neg_wts+= wts(i);
85       ++n_neg;
86     }
87   }
88 
89   assert( n_pos+n_neg== n );
90   wm_pos/=tot_pos_wts;
91 #ifdef DEBUG
92   std::cout<<"wm_pos= "<<wm_pos<<std::endl;
93 #endif
94   // create triples data, so can sort
95   std::vector<vbl_triple<double,int,int> > data;
96 
97   vbl_triple<double,int,int> t;
98   // add data to triples
99   for (unsigned int i=0;i<n;++i)
100   {
101     double k= wm_pos-egs[i];
102     t.first=k*k;
103     t.second= outputs[i];
104     t.third = i;
105     data.push_back(t);
106   }
107 
108   vbl_triple<double,int,int> *data_ptr=&data[0];
109   std::sort(data_ptr,data_ptr+n);
110 
111   double wt_pos=0;
112   double wt_neg=0;
113   double min_error= 1000000;
114   double min_thresh= -1;
115   for (unsigned int i=0;i<n;++i)
116   {
117     if ( data[i].second == 0 ) wt_neg+= wts[ data[i].third] ;
118     else if ( data[i].second == 1 ) wt_pos+= wts[ data[i].third];
119     else
120     {
121       std::cout<<"ERROR: clsfy_mean_square_1d_builder::build()\n"
122               <<"Unrecognised output value in triple (ie must be 0 or 1)\n"
123               <<"data.second="<<data[i].second<<std::endl;
124       std::abort();
125     }
126     double error= tot_pos_wts-wt_pos+wt_neg;
127 #ifdef DEBUG
128     std::cout<<"data[i].first= "<<data[i].first<<std::endl
129             <<"data[i].second= "<<data[i].second<<std::endl
130             <<"data[i].third= "<<data[i].third<<std::endl
131 
132             <<"wt_pos= "<<wt_pos<<std::endl
133             <<"tot_wts1= "<<tot_wts1<<std::endl
134             <<"wt_neg= "<<wt_neg<<std::endl
135 
136             <<"error= "<<error<<std::endl;
137 #endif
138     if ( error< min_error )
139     {
140       min_error= error;
141       min_thresh = data[i].first + 0.001 ;
142     }
143   }
144 
145   assert( std::fabs (wt_pos - tot_pos_wts) < 1e-9 );
146   assert( std::fabs (wt_neg - tot_neg_wts) < 1e-9 );
147 #ifdef DEBUG
148   std::cout<<"min_error= "<<min_error<<std::endl
149           <<"min_thresh= "<<min_thresh<<std::endl;
150 #endif
151   // pass parameters to classifier
152   classifier.set_params(vnl_double_2(wm_pos,min_thresh).as_vector());
153   return min_error;
154 }
155 
156 
157 //: Build a mean_square classifier
158 // nb here egs0 are -ve examples
159 // and egs1 are +ve examples
build(clsfy_classifier_1d & classifier,vnl_vector<double> & egs0,vnl_vector<double> & wts0,vnl_vector<double> & egs1,vnl_vector<double> & wts1) const160 double clsfy_mean_square_1d_builder::build(clsfy_classifier_1d& classifier,
161                                            vnl_vector<double>& egs0,
162                                            vnl_vector<double>& wts0,
163                                            vnl_vector<double>& egs1,
164                                            vnl_vector<double>& wts1)  const
165 {
166   // this method sorts the data and passes it to the method below
167   assert(classifier.is_class("clsfy_mean_square_1d"));
168 
169   // find mean of positive data (ie egs1) then calc square distance from mean
170   // for each example
171   unsigned int n0 = egs0.size();
172   unsigned int n1 = egs1.size();
173   assert (wts0.size() == n0 );
174   assert (wts1.size() == n1 );
175 
176   // calc weighted mean of positive data
177   double tot_wts1= wts1.mean()*n1;
178   double wm_pos=0.0;
179   for (unsigned int i=0; i< n1; ++i)
180   {
181     wm_pos+= wts1(i)*egs1(i);
182 #ifdef DEBUG
183     std::cout<<"egs1("<<i<<")= "<<egs1(i)<<std::endl
184             <<"wts1("<<i<<")= "<<wts1(i)<<std::endl;
185 #endif
186   }
187   wm_pos/=tot_wts1;
188 
189   std::cout<<"wm_pos= "<<wm_pos<<std::endl;
190 
191   std::vector<vbl_triple<double,int,int> > data;
192 
193   vnl_vector<double> wts(n0+n1);
194   vbl_triple<double,int,int> t;
195   // add data for class 0
196   for (unsigned int i=0;i<n0;++i)
197   {
198     double k= wm_pos-egs0[i];
199     t.first=k*k;
200     t.second=0;
201     t.third = i;
202     wts(i)= wts0[i];
203     data.push_back(t);
204   }
205 
206   // add data for class 1
207   for (unsigned int i=0;i<n1;++i)
208   {
209     double k= wm_pos-egs1[i];
210     t.first=k*k;
211     t.second=1;
212     t.third = i+n0;
213     wts(i+n0)= wts1[i];
214     data.push_back(t);
215   }
216 
217   unsigned int n=n0+n1;
218 
219   vbl_triple<double,int,int> *data_ptr=&data[0];
220   std::sort(data_ptr,data_ptr+n);
221 
222   double wt_pos=0;
223   double wt_neg=0;
224   double min_error= 1000000;
225   double min_thresh= -1;
226   for (unsigned int i=0;i<n;++i)
227   {
228     if ( data[i].second == 0 ) wt_neg+= wts[ data[i].third] ;
229     else if ( data[i].second == 1 ) wt_pos+= wts[ data[i].third];
230     else
231     {
232       std::cout<<"ERROR: clsfy_mean_square_1d_builder::build()\n"
233               <<"Unrecognised output value in triple (ie must be 0 or 1)\n"
234               <<"data.second="<<data[i].second<<std::endl;
235       std::abort();
236     }
237     double error= tot_wts1-wt_pos+wt_neg;
238 #ifdef DEBUG
239     std::cout<<"data[i].first= "<<data[i].first<<std::endl
240             <<"data[i].second= "<<data[i].second<<std::endl
241             <<"data[i].third= "<<data[i].third<<std::endl
242 
243             <<"wt_pos= "<<wt_pos<<std::endl
244             <<"tot_wts1= "<<tot_wts1<<std::endl
245             <<"wt_neg= "<<wt_neg<<std::endl
246 
247             <<"error= "<<error<<std::endl;
248 #endif
249     if ( error< min_error )
250     {
251       min_error= error;
252       min_thresh = data[i].first + 0.001 ;
253     }
254   }
255 
256   assert( std::fabs (wt_pos - tot_wts1) < 1e-9 );
257   assert( std::fabs (wt_neg - wts0.mean()*n0) < 1e-9 );
258   std::cout<<"min_error= "<<min_error<<std::endl
259           <<"min_thresh= "<<min_thresh<<std::endl;
260 
261   // pass parameters to classifier
262   classifier.set_params(vnl_double_2(wm_pos,min_thresh).as_vector());
263   return min_error;
264 }
265 
266 
267 //: Train classifier, returning weighted error
268 //   Assumes two classes
build_from_sorted_data(clsfy_classifier_1d &,const vbl_triple<double,int,int> *,const vnl_vector<double> &) const269 double clsfy_mean_square_1d_builder::build_from_sorted_data(
270                                   clsfy_classifier_1d& /*classifier*/,
271                                   const vbl_triple<double,int,int>* /*data*/,
272                                   const vnl_vector<double>& /*wts*/
273                                   ) const
274 {
275   std::cout<<"ERROR: clsfy_mean_square_1d_builder::build_from_sorted_data()\n"
276           <<"Function not implemented because can't use pre-sorted data\n"
277           <<"the weighted mean of the data is needed to calc the ordering!\n";
278   std::abort();
279 
280   return 0.0;
281 }
282 
283 //=======================================================================
284 
is_a() const285 std::string clsfy_mean_square_1d_builder::is_a() const
286 {
287   return std::string("clsfy_mean_square_1d_builder");
288 }
289 
is_class(std::string const & s) const290 bool clsfy_mean_square_1d_builder::is_class(std::string const& s) const
291 {
292   return s == clsfy_mean_square_1d_builder::is_a() || clsfy_builder_1d::is_class(s);
293 }
294 
295 //=======================================================================
296 
#if 0 // two functions commented out

//: Copy constructor
// required if data stored on the heap is present in this derived class
clsfy_mean_square_1d_builder::clsfy_mean_square_1d_builder(
                             const clsfy_mean_square_1d_builder& new_b) :
  data_ptr_(nullptr)
{
  *this = new_b;
}

//=======================================================================

//: Assignment operator
// required if data stored on the heap is present in this derived class
clsfy_mean_square_1d_builder&
clsfy_mean_square_1d_builder::operator=(const clsfy_mean_square_1d_builder& new_b)
{
  if (this != &new_b)
  {
    // Replace the heap member with a clone of the source's (or nothing).
    delete data_ptr_;
    data_ptr_ = new_b.data_ptr_ ? new_b.data_ptr_->clone() : nullptr;

    // Copy normal member variables
    data_ = new_b.data_;
  }

  return *this;
}

#endif // 0
328 
329 //=======================================================================
330 
clone() const331 clsfy_builder_1d* clsfy_mean_square_1d_builder::clone() const
332 {
333   return new clsfy_mean_square_1d_builder(*this);
334 }
335 
336 //=======================================================================
337 
338 // required if data is present in this base class
print_summary(std::ostream &) const339 void clsfy_mean_square_1d_builder::print_summary(std::ostream& /*os*/) const
340 {
341   // clsfy_builder_1d::print_summary(os); // Uncomment this line if it has one.
342   // vsl_print_summary(os, data_); // Example of data output
343 
344   std::cerr << "clsfy_mean_square_1d_builder::print_summary() NYI\n";
345 }
346 
347 //=======================================================================
348 
349 // required if data is present in this base class
b_write(vsl_b_ostream &) const350 void clsfy_mean_square_1d_builder::b_write(vsl_b_ostream& /*bfs*/) const
351 {
352   //vsl_b_write(bfs, version_no());
353   //clsfy_builder_1d::b_write(bfs);  // Needed if base has any data
354   //vsl_b_write(bfs, data_);
355   std::cerr << "clsfy_mean_square_1d_builder::b_write() NYI\n";
356 }
357 
358 //=======================================================================
359 
360 // required if data is present in this base class
b_read(vsl_b_istream &)361 void clsfy_mean_square_1d_builder::b_read(vsl_b_istream& /*bfs*/)
362 {
363   std::cerr << "clsfy_mean_square_1d_builder::b_read() NYI\n";
364 #if 0
365   if (!bfs) return;
366 
367   short version;
368   vsl_b_read(bfs,version);
369   switch (version)
370   {
371   case (1):
372     //clsfy_builder_1d::b_read(bfs);  // Needed if base has any data
373     vsl_b_read(bfs,data_);
374     break;
375   default:
376     std::cerr << "I/O ERROR: vsl_b_read(vsl_b_istream&, clsfy_mean_square_1d_builder&)\n"
377              << "           Unknown version number "<< version << '\n';
378     bfs.is().clear(std::ios::badbit); // Set an unrecoverable IO error on stream
379     return;
380   }
381 #endif
382 }
383