1 #include <iostream>
2 #include <algorithm>
3 #include <limits>
4 #include <cmath>
5 #include <sstream>
6 #include "brec_part_hierarchy_learner.h"
7 //:
8 // \file
9 // \author Ozge C Ozcanli (ozge@lems.brown.edu)
10 // \date Jan 19, 2009
11 
12 #include "brec_part_hierarchy_learner_sptr.h"
13 
14 #include <bsta/bsta_histogram.h>
15 #include <brec/brec_part_gaussian.h>
16 #include <brec/brec_part_hierarchy.h>
17 #include <brec/brec_part_hierarchy_detector.h>
18 
19 #include "vil/vil_image_view.h"
20 #ifdef _MSC_VER
21 #  include "vcl_msvc_warnings.h"
22 #endif
23 
24 //: check the "true" part of the mask
check_equal(vbl_array_2d<bool> & left_array,vbl_array_2d<bool> & right_array)25 bool check_equal(vbl_array_2d<bool>& left_array, vbl_array_2d<bool>& right_array)
26 {
27   vbl_array_2d<bool>::size_type i1, j1;
28 
29   // first check if they have the same number of true pixels
30   int cnt_left = 0;
31   for (i1 = 0; i1 < left_array.rows(); i1++) {
32     for (j1 = 0; j1 < left_array.cols(); j1++) {
33       if (left_array.get_rows()[i1][j1])
34         cnt_left++;
35     }
36   }
37   int cnt_right = 0;
38   for (i1 = 0; i1 < right_array.rows(); i1++) {
39     for (j1 = 0; j1 < right_array.cols(); j1++) {
40       if (right_array.get_rows()[i1][j1])
41         cnt_right++;
42     }
43   }
44   if (cnt_left != cnt_right)
45     return false;
46 
47   // find the first "true" pixel on both masks and use that pixel to centralize the second part on top of the other one
48   bool found_it = false;
49   for (i1 = 0; i1 < left_array.rows(); i1++) {
50     for (j1 = 0; j1 < left_array.cols(); j1++) {
51       if (left_array.get_rows()[i1][j1]) {
52         found_it = true;
53         break;
54       }
55     }
56     if (found_it)
57       break;
58   }
59   if (!found_it)
60     return false;
61 
62   vbl_array_2d<bool>::size_type i2, j2;
63   found_it = false;
64   for (i2 = 0; i2 < right_array.rows(); i2++) {
65     for (j2 = 0; j2 < right_array.cols(); j2++) {
66       if (right_array.get_rows()[i2][j2]) {
67         found_it = true;
68         break;
69       }
70     }
71     if (found_it)
72       break;
73   }
74   if (!found_it)
75     return false;
76 
77   // now we think i1,j1 and i2,j2 correspond
78   vbl_array_2d<bool>::size_type i11, j11;
79   for (i11 = 0; i11 < left_array.rows(); i11++) {
80     for (j11 = 0; j11 < left_array.cols(); j11++) {
81       if (left_array.get_rows()[i11][j11]) {
82         // we expect right_array to have a true value at the corresponding spot when centered on i1, j2, otherwise they're not equal
83         int i22, j22;
84         i22 = i11-i1+i2;
85         j22 = j11-j1+j2;
86         if (i22 < 0 || i22 >= (int)right_array.rows() || j22 < 0 || j22 >= (int)right_array.cols() || !right_array.get_rows()[i22][j22])
87           return false;
88       }
89     }
90   }
91   return true;
92 }
93 
initialize_layer0_as_gaussians(int ndirs,float lambda_range,float lambda_inc,int n)94 void brec_part_hierarchy_learner::initialize_layer0_as_gaussians(int ndirs, float lambda_range, float lambda_inc, int n)
95 {
96   //n_ = (int)(lambda_range/lambda_inc);
97   n_ = n; // visualization parameter to plot histograms into m files
98   // stats for bright operators
99   float theta = 0.0f;
100   float theta_inc = 180.0f/(float)ndirs;
101   unsigned type_cnt = 0;
102 
103   std::vector<vbl_array_2d<bool> > masks;
104 
105   for (float lambda0 = lambda_inc; lambda0 <= lambda_range; lambda0 += lambda_inc)
106   {
107     for (float lambda1 = lambda_inc; lambda1 <= lambda_range; lambda1 += lambda_inc)
108     {
109       theta = 0.0f;
110       for (int i = 0; i < ndirs; i++) {
111         // initialize bright operator
112         float adjusted_theta = theta > 90 ? -(180-theta) : theta;  // operators work in [-pi, pi] range
113         brec_part_gaussian_sptr p = new brec_part_gaussian(0.0f, 0.0f, 0.0f, lambda0, lambda1, adjusted_theta, true, type_cnt);
114         // do not add it if it has exactly the same one as a previous operator
115         bool already_exists = false;
116         for (auto & mask : masks) {
117           if (check_equal(p->mask_, mask)) {
118             already_exists = true;
119             break;
120           }
121         }
122 
123         if (already_exists)
124           continue;
125 
126         masks.push_back(p->mask_);
127 
128         // create histogram for foreground stats
129         auto* h = new bsta_histogram<float>(0.0f, 2.0f, 100); // was (-7.0f, 1.0f, 32)
130         std::pair<brec_part_instance_sptr, bsta_histogram<float>* > pa(p->cast_to_instance(), h);
131         stats_layer0_.push_back(pa);
132         type_cnt++;
133 
134         // initialize the dark operator as well
135         brec_part_gaussian_sptr pd = new brec_part_gaussian(0.0f, 0.0f, 0.0f, lambda0, lambda1, adjusted_theta, false, type_cnt);
136         auto* hd = new bsta_histogram<float>(0.0f, 2.0f, 100);
137         std::pair<brec_part_instance_sptr, bsta_histogram<float>* > pad(pd->cast_to_instance(), hd);
138         stats_layer0_.push_back(pad);
139         type_cnt++;
140 
141         theta += theta_inc;
142       }
143     }
144   }
145   std::cout << "initialized: " << type_cnt << " (array size: " << stats_layer0_.size() << ") operators\n";
146 #if 1
147   std::cout << "initialized as follows:\n";
148   for (auto & i : stats_layer0_) {
149     brec_part_gaussian_sptr p = i.first->cast_to_gaussian();
150     std::cout << "l0: " << p->lambda0_ << " l1: " << p->lambda1_ << " t: " << p->theta_;
151     if (p->bright_)
152       std::cout << " bright\n";
153     else
154       std::cout << " dark\n";
155   }
156   std::cout << "--------------------------\n"
157            << " masks size: " << masks.size() << std::endl;
158   for (const auto & mask : masks) {
159     std::cout << mask << std::endl
160              << "--------------------------\n";
161   }
162 #endif
163 }
164 
165 // assumes float img with values in [0,1] range
layer0_collect_stats(vil_image_view<float> & inp,vil_image_view<float> & fg_prob_img,vil_image_view<bool> & mask)166 void brec_part_hierarchy_learner::layer0_collect_stats(vil_image_view<float>& inp, vil_image_view<float>& fg_prob_img, vil_image_view<bool>& mask)
167 {
168   for (auto & i : stats_layer0_) {
169     brec_part_instance_sptr p = i.first;
170     bsta_histogram<float> *h = i.second;
171     p->update_response_hist(inp, fg_prob_img, mask, *h);
172   }
173 }
174 
layer0_collect_stats(vil_image_view<float> & inp,vil_image_view<float> & fg_prob_img)175 void brec_part_hierarchy_learner::layer0_collect_stats(vil_image_view<float>& inp, vil_image_view<float>& fg_prob_img)
176 {
177   vil_image_view<bool> mask(inp.ni(), inp.nj());
178   mask.fill(true);
179   layer0_collect_stats(inp, fg_prob_img, mask);
180 }
181 
layer0_fit_parametric_dist()182 void brec_part_hierarchy_learner::layer0_fit_parametric_dist()
183 {
184   for (auto & i : stats_layer0_) {
185     brec_part_instance_sptr p = i.first;
186     bsta_histogram<float> *h = i.second;
187     p->fit_distribution_to_response_hist(*h);  // the computed params are saved at the instance
188   }
189 }
190 
layer0_collect_posterior_stats(vil_image_view<float> & inp,vil_image_view<float> & fg_prob_img,vil_image_view<bool> & mask,vil_image_view<float> & mean_img,vil_image_view<float> & std_dev_img)191 void brec_part_hierarchy_learner::layer0_collect_posterior_stats(vil_image_view<float>& inp,
192                                                                  vil_image_view<float>& fg_prob_img,
193                                                                  vil_image_view<bool>& mask,
194                                                                  vil_image_view<float>& mean_img,
195                                                                  vil_image_view<float>& std_dev_img)
196 {
197   for (auto & i : stats_layer0_) {
198     brec_part_instance_sptr p = i.first;
199     p->update_foreground_posterior(inp, fg_prob_img, mask, mean_img, std_dev_img);  // the computed params are saved at the instance
200   }
201 }
202 
layer0_collect_posterior_stats(vil_image_view<float> & inp,vil_image_view<float> & fg_prob_img,vil_image_view<float> & mean_img,vil_image_view<float> & std_dev_img)203 void brec_part_hierarchy_learner::layer0_collect_posterior_stats(vil_image_view<float>& inp,
204                                                                  vil_image_view<float>& fg_prob_img,
205                                                                  vil_image_view<float>& mean_img,
206                                                                  vil_image_view<float>& std_dev_img)
207 {
208   vil_image_view<bool> mask(inp.ni(), inp.nj());
209   mask.fill(true);
210   layer0_collect_posterior_stats(inp, fg_prob_img, mask, mean_img, std_dev_img);
211 }
212 
rho_more(const std::pair<brec_part_instance_sptr,bsta_histogram<float> * > & p1,const std::pair<brec_part_instance_sptr,bsta_histogram<float> * > & p2)213 bool rho_more(const std::pair<brec_part_instance_sptr, bsta_histogram<float>*>& p1,
214               const std::pair<brec_part_instance_sptr, bsta_histogram<float>*>& p2)
215 {
216   return p1.first->rho_c_f_ > p2.first->rho_c_f_;
217 }
218 
219 //: create a part hierarchy of primitive parts which are added with respect to their average rho_ (posterior ratios)
220 //  This will be used to construct layers 1 and above
layer0_rank_and_create_hierarchy(int N)221 brec_part_hierarchy_sptr brec_part_hierarchy_learner::layer0_rank_and_create_hierarchy(int N)
222 {
223   brec_part_hierarchy_sptr h = new brec_part_hierarchy();
224   std::sort(stats_layer0_.begin(), stats_layer0_.end(), rho_more);
225 
226   int cnt = (N < (int)stats_layer0_.size() ? N : (int)stats_layer0_.size());
227   for (int i = 0; i < cnt; i++) {
228     brec_part_instance_sptr p = stats_layer0_[i].first;
229     std::cout << "adding layer0 i: " << i << " type: " << p->type_ << " rho_: " << p->rho_c_f_ << std::endl;
230     brec_part_base_sptr p_0 = new brec_part_base(0, p->type_);
231     h->add_vertex(p_0);
232     h->add_dummy_primitive_instance(p);
233   }
234 
235   for (auto & i : stats_layer0_) {
236     delete i.second;
237   }
238   stats_layer0_.clear();
239 
240   return h;
241 }
242 
243 //: initialize learner to construct layer_n as pairs of layer_n-1 of the given hierarchy.
244 //  radius is used to initialize the histograms
245 //  we use 8 bins for angle in [0, 2*pi] range and 8 bins for distance in [0,radius] range
initialize_layer_n_as_pairs(const brec_part_hierarchy_sptr & h,unsigned layer_id,unsigned nclasses,float radius)246 bool brec_part_hierarchy_learner::initialize_layer_n_as_pairs(const brec_part_hierarchy_sptr& h, unsigned layer_id, unsigned nclasses, float radius)
247 {
248   if (!layer_id)
249     return false;
250 
251   h_ = h;
252   radius_ = radius;
253 
254   type_cnt_ = 0;
255   for (unsigned c = 0; c < nclasses; c++) {
256     auto* map = new class_map();
257 
258     for (auto it = h->vertices_begin(); it != h->vertices_end(); it++) {
259       if ((*it)->layer_ == layer_id-1) {
260         brec_part_base_sptr p1 = (*it);
261 
262         for (auto it2 = h->vertices_begin(); it2 != h->vertices_end(); it2++) {
263           if ((*it2)->layer_ == layer_id-1) {
264             brec_part_base_sptr p2 = (*it2);
265 
266             brec_part_instance_sptr p_n = new brec_part_instance(layer_id, type_cnt_, brec_part_instance_kind::COMPOSED, 0.0f, 0.0f, 0.0f);
267             type_cnt_++;
268             brec_hierarchy_edge_sptr e1 = new brec_hierarchy_edge(p_n->cast_to_base(), p1, true);
269             p_n->add_outgoing_edge(e1);
270 
271             brec_hierarchy_edge_sptr e2 = new brec_hierarchy_edge(p_n->cast_to_base(), p2, false);
272             p_n->add_outgoing_edge(e2);
273 
274             auto d_hist = new hist(radius_, 16);
275             auto a_hist = new hist(-vnl_math::pi, vnl_math::pi, 16);
276 
277             //sample_set_ptr d_mss = new sample_set(radius_/8.0f); //set mean shift bandwidth to the size of 2 bins
278             //sample_set_ptr a_mss = new sample_set(float(2.0f*vnl_math::pi/8.0f)); //set mean shift bandwidth to the size of 2 bins
279             auto mss = new sample_set();  // bandwidth is unimportant - because mean-shift will be applied on 1D marginalized version of this 2D data
280             d_bandwidth_ = radius_/8.0f;  //set mean shift bandwidth to the size of 2 bins
281             a_bandwidth_ = float(2.0f*vnl_math::pi/8.0f);  //set mean shift bandwidth to the size of 2 bins
282 
283             hist_ptr_pair ph(d_hist, a_hist);
284             //sample_set_ptr_pair ps(d_mss, a_mss);
285 
286             //std::pair<hist_ptr_pair, sample_set_ptr_pair> hist_pair(ph, ps);
287             std::pair<hist_ptr_pair, sample_set_ptr> hist_pair(ph, mss);
288 
289             std::pair<brec_part_instance_sptr, std::pair<hist_ptr_pair, sample_set_ptr> > pa(p_n, hist_pair);
290 
291             std::pair<unsigned, unsigned> pa_id(p1->type_, p2->type_);
292             (*map)[pa_id] = pa;
293           }
294         }
295       }
296     }
297 
298     stats_layer_n_[c] = map;
299   }
300 
301   std::cout << "initialized learner for: " << stats_layer_n_.size() << " classes\n";
302 #if 1
303   unsigned cnt = 0;
304   for (auto & it : stats_layer_n_) {
305     class_map* map = it.second;
306     std::cout << "\t class: " << cnt++ << ", initialized with " << map->size() << " pairs:\n";
307     class_map::iterator m_it;
308     for (m_it = map->begin(); m_it != map->end(); m_it++) {
309       std::pair<unsigned, unsigned> id_p = (*m_it).first;
310       std::pair<brec_part_instance_sptr, std::pair<hist_ptr_pair, sample_set_ptr> > pa = (*m_it).second;
311       std::cout << '(' << id_p.first << ", " << id_p.second << ") ";
312       for (auto eit = pa.first->out_edges_begin(); eit != pa.first->out_edges_end(); eit++) {
313         std::cout << '(' << (*eit)->target()->layer_ << ", " << (*eit)->target()->type_ << ") ";
314       }
315       std::cout << std::endl;
316     }
317   }
318 
319 #endif
320 
321   return true;
322 }
323 
324 //: a helper function to find a bin in one of 8 bins in the range [0,radius] for a given angle value.
325 //  angle values are in the range [-pi, pi], so map this range to [0,radius]
map_range_for_angle(float a,float radius)326 float map_range_for_angle(float a, float radius)
327 {
328   float val = float(vnl_math::pi)+a;  // mapped to [0,2*pi]
329   val = val/(2.0f*float(vnl_math::pi));  // mapped to [0,1]
330   val = val*radius;  // mapped to [0,radius]
331   return val;
332 }
333 
//: Convert polar coordinates (\a angle in radians, \a radius) to Cartesian coordinates
//  (\a x, \a y) whose origin is shifted to (\a max_radius, \a max_radius).
void map_to_cartesian(float angle, float radius, float max_radius, float& x, float& y)
{
  const float dx = radius * std::cos(angle);
  const float dy = radius * std::sin(angle);
  x = max_radius + dx;
  y = max_radius + dy;
}
339 
340 //: collect joint stats to construct parts of layer with layer_id using detected parts of layer_id-1
341 //  Collect stats for a pair if they exist within radius pixels of each other
layer_n_collect_stats(const brec_part_hierarchy_detector_sptr & hd,unsigned layer_id,unsigned class_id)342 bool brec_part_hierarchy_learner::layer_n_collect_stats(const brec_part_hierarchy_detector_sptr& hd, unsigned layer_id, unsigned class_id)
343 {
344   if (!layer_id) {
345     std::cout << "In brec_part_hierarchy_learner::layer_n_collect_stats() -- layer_id is zero!!\n";
346     return false;
347   }
348 
349   brec_part_hierarchy_sptr h = hd->get_hierarchy();
350   std::vector<brec_part_instance_sptr> parts = hd->get_parts(layer_id-1);
351   Rtree_type* rtree = hd->get_tree(layer_id-1);
352 
353   class_map* map;
354   layer_n_map::iterator it;
355   it = stats_layer_n_.find(class_id);
356   if (it != stats_layer_n_.end())
357     map = (*it).second;
358   else {
359     std::cout << "WARNING: This class: " << class_id << " was not initialized!\n";
360     return false;
361   }
362 
363   std::cout << "there are " << parts.size() << " parts of layer: " << layer_id-1 << " in the detector!\n";
364   class_map::iterator qit;
365 
366   vnl_random rng;
367 
368   // go through each detected instance as central part
369   for (auto & part : parts) {
370     if (part->layer_ == layer_id-1)
371     {
372       vgl_box_2d<float> probe = part->get_probe_box(radius_);
373       std::vector<brec_part_instance_sptr> found;
374       rtree->get(probe, found);
375 
376       bsta_gaussian_sphere<float, 2> jj_dist;
377       jj_dist.set_mean(part->location()); jj_dist.set_var(1.0f);
378 
379       for (auto & kk : found) {
380         if (kk == part)
381           continue;
382         if (kk->layer_ == layer_id-1) {
383           // update stats for this pair
384           vnl_vector_fixed<float,2> cent_dif = kk->location() - part->location();
385           // calculate angle and dists
386           float a, d;
387           brec_hierarchy_edge::calculate_dist_angle(part->cast_to_instance(), cent_dif, d, a);
388 
389 #if 0  // this was 1 for the digit application
390           // make sure the samples are well-separated
391           if (d < radius_/2)
392             continue;  // discard this pair
393 #endif
394           // create a bunch of samples by assuming 1 pixel variance in the pixel locations
395           bsta_gaussian_sphere<float, 2> kk_dist;
396           kk_dist.set_mean(kk->location()); kk_dist.set_var(1.0f);
397 
398           std::vector<float> as, ds;
399           as.push_back(a); ds.push_back(d);
400           for (unsigned mmm = 0; mmm < 10; mmm++) {
401             cent_dif = kk_dist.sample(rng) - jj_dist.sample(rng);
402             brec_hierarchy_edge::calculate_dist_angle(part->cast_to_instance(), cent_dif, d, a);
403             as.push_back(a); ds.push_back(d);
404           }
405 
406           std::pair<unsigned, unsigned> qid(part->type_, kk->type_);
407           qit = map->find(qid);
408           if (qit != map->end()) {  // found the histogram
409             hist_ptr d_hist = (qit->second).second.first.first;
410             hist_ptr a_hist = (qit->second).second.first.second;
411 
412             sample_set_ptr set = (qit->second).second.second;
413 
414             //brec_part_instance_sptr layer_n_part = (qit->second).first;
415 
416             auto w1 = (float)part->cast_to_instance()->rho_c_f_;
417             auto w2 = (float)kk->cast_to_instance()->rho_c_f_;
418 
419             for (unsigned mmm = 0; mmm < as.size(); mmm++) {
420               d_hist->upcount(ds[mmm], w1*w2);
421               a_hist->upcount(as[mmm], w1*w2);
422               vnl_vector_fixed<double, 2> sample(ds[mmm], as[mmm]);
423               set->insert_sample(sample, w1*w2);
424             }
425           }
426         }
427       }
428     }
429   }
430   return true;
431 }
432 
433 //: uses the joint histograms to fit Gaussian distributions to distance for 8 orientations.
434 //  Replaces the histograms with the fitted distributions' histograms
435 //  Populate layer_n of current hierarchy of the class with parts which have models that have highest data log-likelihood
layer_n_fit_distributions(unsigned class_id,unsigned layer_id,unsigned M)436 bool brec_part_hierarchy_learner::layer_n_fit_distributions(unsigned class_id, unsigned layer_id, unsigned M)
437 {
438   class_map* map;
439   layer_n_map::iterator it;
440   it = stats_layer_n_.find(class_id);
441   if (it != stats_layer_n_.end())
442     map = (*it).second;
443   else {
444     std::cout << "WARNING: This class: " << class_id << " was not initialized!\n";
445     return false;
446   }
447 
448   // get the class hierarchy
449   brec_part_hierarchy_sptr class_h;
450   auto h_it = h_map_.find(class_id);
451   if (h_it != h_map_.end()) {
452     class_h = (*h_it).second;
453   }
454   else {  // this means we're building up layer 1 (i.e. n = 1)
455     class_h = new brec_part_hierarchy();
456     // add the primitives from h_
457     for (unsigned i = 0; i < h_->get_dummy_primitive_instances().size(); i++)
458       class_h->add_dummy_primitive_instance(h_->get_dummy_primitive_instances()[i]);
459     // add layer 0 from h_
460     for (auto v_it = h_->vertices_begin(); v_it != h_->vertices_end(); v_it++) {
461       class_h->add_vertex(*v_it);
462     }
463     for (auto e_it = h_->edges_begin(); e_it != h_->edges_end(); e_it++) {
464       class_h->add_edge_no_check(*e_it);
465     }
466 
467     h_map_[class_id] = class_h;
468   }
469 
470   //  first we need the total_weight of all the samples in all the pairs's data
471   double total_weight = 0.0;
472   class_map::iterator qit;
473   for (qit = map->begin(); qit != map->end(); qit++) {
474     sample_set_ptr set = (qit->second).second.second;
475     total_weight += set->total_weight();
476   }
477   std::cout << "total weight of all the data from all the pairs for class: " << class_id << " is: " << total_weight << std::endl;
478 
479   for (qit = map->begin(); qit != map->end(); qit++) {
480     hist_ptr d_hist = (qit->second).second.first.first;
481     hist_ptr a_hist = (qit->second).second.first.second;
482 
483     if (qit->first.first == 23 && qit->first.second == 3)
484       std::cout << "here!\n";
485 
486     //sample_set_ptr d_set = (qit->second).second.second.first;
487     //sample_set_ptr a_set = (qit->second).second.second.second;
488     sample_set_ptr set = (qit->second).second.second;
489 
490     brec_part_instance_sptr layer_n_part = (qit->second).first;
491 
492     unsigned d_nbins = d_hist->nbins();
493     float d_delta = radius_/float(d_nbins);
494 
495     // create 1D marginalized distance sample set from nD set
496     bsta_sample_set<double,1> d_set;
497     if (!bsta_sample_set_marginalize(*set, 0, d_set)) {
498       std::cout << "ERROR: cannot create 1D distance set from set!\n";
499       return false;
500     }
501     d_set.set_bandwidth(d_bandwidth_);
502 
503     bsta_sample_set<double,1> a_set;
504     if (!bsta_sample_set_marginalize(*set, 1, a_set)) {
505       std::cout << "ERROR: cannot create 1D distance set from set!\n";
506       return false;
507     }
508     a_set.set_bandwidth(a_bandwidth_);
509 
510     // run mean_shift on distance sample set
511     bsta_mean_shift<double,1> d_ms;
512     d_ms.find_modes(d_set, 0.01);
513     d_ms.trim_modes(d_set, 2*d_delta);
514     //std::cout << "In layer_n_fit_distributions(" << class_id << ") - type: " << layer_n_part->type_ << ", # of d modes: " << d_set.mode_cnt() << std::endl;
515     d_ms.merge_modes(d_set, 3, 0.01);  // merge the modes with samples less then 3
516     std::cout << "In layer_n_fit_distributions(" << class_id << ") - type: " << layer_n_part->type_ << ", # of d modes: " << d_set.mode_cnt() << std::endl;
517 
518     // get the fitted mixture using ss, ms set the assignments after trimming
519     bool ok = true;
520     bsta_mixture<bsta_num_obs<bsta_gaussian_sphere<double,1> > > d_out_dist;
521     delete d_hist;
522     d_hist = new bsta_histogram<double>(radius_, d_nbins);
523     if (!bsta_sample_set_fit_distribution<double>(d_set, d_out_dist)) {
524       std::cout << "Warning: Cannot fit a mixture to the distribution of the class: " << class_id << '\n';
525       ok = false;
526     }
527     else {
528       // now replace the histogram with the mean shift fitted one
529       for (unsigned aa = 0; aa < d_nbins; aa++) {
530         float pt = float(aa+1)*d_delta;
531         double val = d_out_dist.prob_density(pt);
532         d_hist->upcount(float(aa+1)*d_delta, val);
533       }
534     }
535 
536     (qit->second).second.first.first = d_hist;
537 
538     // run mean_shift on angle sample set
539     unsigned a_nbins = a_hist->nbins();
540     auto a_delta = float(vnl_math::twopi/a_nbins);
541 
542     bsta_mean_shift<double,1> a_ms;
543     a_ms.find_modes(a_set, 0.01f);
544     a_ms.trim_modes(a_set, 2*a_delta);
545     //std::cout << "In layer_n_fit_distributions(" << class_id << ") - type: " << layer_n_part->type_ << ", # of a modes: " << a_set.mode_cnt() << std::endl;
546     a_ms.merge_modes(a_set, 3, 0.01f);  // merge the modes with samples less then 3
547     std::cout << "In layer_n_fit_distributions(" << class_id << ") - type: " << layer_n_part->type_ << ", # of a modes: " << a_set.mode_cnt() << std::endl;
548 
549     // get the fitted mixture using ss, ms set the assignments after trimming
550     bsta_mixture<bsta_num_obs<bsta_gaussian_sphere<double,1> > > a_out_dist;
551     delete a_hist;
552     a_hist = new bsta_histogram<double>(-vnl_math::pi, vnl_math::pi, a_nbins);
553     if (!bsta_sample_set_fit_distribution<double>(a_set, a_out_dist)) {
554       std::cout << "Warning: Cannot fit a mixture to the distribution of the class: " << class_id << '\n';
555       ok = false;
556     }
557     else {
558       // now replace the histogram with the mean shift fitted one
559       for (unsigned aa = 0; aa < a_nbins; aa++) {
560         float pt = float(aa+1)*a_delta;
561         double val = a_out_dist.prob_density(pt);
562         a_hist->upcount(float(aa+1)*a_delta, val);
563       }
564     }
565 
566     (qit->second).second.first.second = a_hist;
567 
568     if (ok)
569     {
570       // find the likelihood for each of d_out_dist.num_components()*a_out_dist.num_components() possible models
571       //  update the layer_n parts of current hierarchy for the class
572       for (unsigned mi = 0; mi < d_out_dist.num_components(); mi++) {
573         for (unsigned mj = 0; mj < a_out_dist.num_components(); mj++) {
574           double w_sum = 0.0;
575           // compute ll which is the data likelihood for this pair, and w_sum/total_weight is the prior probability
576           double ll = bsta_sample_set_log_likelihood(*set, d_out_dist.distribution(mi), d_out_dist.weight(mi), a_out_dist.distribution(mj), a_out_dist.weight(mj), w_sum);
577 
578           if (ll < -1e10)
579             continue;  // skip it if it overflowed
580 
581           std::cout << "\t ll: " << ll << ' ';
582 
583           double ratio = ll;
584 
585           if (stats_layer_n_.size() > 1) {
586           // find the likelihood for each class's sample sets
587           double best_class_ll = -std::numeric_limits<double>::infinity();
588           for (auto class_it = stats_layer_n_.begin(); class_it != stats_layer_n_.end(); class_it++) {
589             if (it == class_it)
590               continue;
591 
592             std::pair<unsigned, unsigned> op_pair(qit->first.first, qit->first.second);
593             auto class_pair_it = (*class_it).second->find(op_pair);
594             if (class_pair_it == (*class_it).second->end()) {
595               std::cout << "Error: One of the classes was not initialized for the pair: " << qit->first.first << ' ' << qit->first.second << '\n';
596               return false;
597             }
598             sample_set_ptr class_set = (class_pair_it->second).second.second;
599             double class_w_sum = 0.0;
600             double class_ll = bsta_sample_set_log_likelihood(*class_set, d_out_dist.distribution(mi), d_out_dist.weight(mi), a_out_dist.distribution(mj), a_out_dist.weight(mj), class_w_sum);
601             if (class_ll < -1e10) // skip it if overflowed
602               continue;
603             if (class_ll > best_class_ll)
604               best_class_ll = class_ll;
605           }
606           if (best_class_ll < -1e10) // skip this distribution pair if overflowed for all the classes
607             continue;
608 
609           ratio -= best_class_ll;
610           std::cout << "\t best_class_ll: " << best_class_ll << " ratio: " << ratio << std::endl;
611           }
612 
613           std::cout << "\t ratio: " << ratio << std::endl;
614 
615           // prepare the part
616           brec_part_base_sptr p_n = new brec_part_base(layer_id, type_cnt_);
617           type_cnt_++;
618 
619           unsigned cpl = layer_n_part->edge_to_central_part()->target()->layer_;
620           unsigned cpt = layer_n_part->edge_to_central_part()->target()->type_;
621           brec_part_base_sptr p_n_p1 = class_h->get_node(cpl, cpt);
622 
623           unsigned spl = layer_n_part->edge_to_second_part()->target()->layer_;
624           unsigned spt = layer_n_part->edge_to_second_part()->target()->type_;
625           brec_part_base_sptr p_n_p2 = class_h->get_node(spl, spt);
626 
627           brec_hierarchy_edge_sptr e1 = new brec_hierarchy_edge(p_n, p_n_p1, true);
628           p_n->add_outgoing_edge(e1);
629 
630           brec_hierarchy_edge_sptr e2 = new brec_hierarchy_edge(p_n, p_n_p2, false);
631           p_n->add_outgoing_edge(e2);
632 
633           //p_n->prior_prob_ = w_sum/total_weight;
634           p_n->prior_prob_ = w_sum;
635           p_n->log_likelihood_ = ratio;
636           e2->set_model(d_out_dist.distribution(mi), a_out_dist.distribution(mj), d_out_dist.weight(mi)*a_out_dist.weight(mj));
637 
638           // if this is better than top M modes replace/place it in the hierarchy
639           if (class_h->layer_cnt(layer_id) < M) {  // insert directly
640             class_h->add_vertex(p_n);
641           }
642           else {  // replace one of the existing ones
643             // traverse all layer_n nodes and replace the one with worst ll
644             double min = 1e6;
645             brec_part_hierarchy::vertex_iterator v_min_it;
646             for (auto v_it = class_h->vertices_begin(); v_it != class_h->vertices_end(); v_it++) {
647               if ((*v_it)->layer_ != layer_id)
648                 continue;
649 
650               if (min > (*v_it)->log_likelihood_) {
651                 v_min_it = v_it;
652                 min = (*v_it)->log_likelihood_;
653               }
654             }
655             if (min < ratio) {
656               if (!class_h->remove_vertex(*v_min_it)) {
657                 std::cout << "ERROR: brec_part_hierarchy_learner::layer_n_fit_distributions() -- cannot delete vertex from hierarchy!\n";
658                 return false;
659               }
660               class_h->add_vertex(p_n);
661             }
662           }
663         }
664       }
665     }  // if ok (both histograms have been fitted)
666   }
667 
668   // now fix the hierarchy with the added nodes
669   for (auto v_it = class_h->vertices_begin(); v_it != class_h->vertices_end(); v_it++) {
670     if ((*v_it)->layer_ != layer_id)
671       continue;
672     brec_part_base_sptr p = (*v_it);
673     // add the edges of this part to the hierarchy
674     for (auto e_it = p->out_edges_begin(); e_it != p->out_edges_end(); e_it++) {
675       class_h->add_edge_no_check((*e_it));
676       (*e_it)->target()->add_incoming_edge((*e_it));
677     }
678   }
679 
680   return true;
681 }
682 
print_layer0()683 void brec_part_hierarchy_learner::print_layer0()
684 {
685   for (auto & i : stats_layer0_) {
686     brec_part_instance_sptr pi = i.first;
687     bsta_histogram<float> *h = i.second;
688     if (pi->kind_ == brec_part_instance_kind::GAUSSIAN) {
689       brec_part_gaussian_sptr p = pi->cast_to_gaussian();
690       if (p->bright_)
691         std::cout << "--- lambda0 " << p->lambda0_ << " --- lambda1 " << p->lambda1_ << " --- theta " << p->theta_ << " --- bright ---\n";
692       else
693         std::cout << "--- lambda0 " << p->lambda0_ << " --- lambda1 " << p->lambda1_ << " --- theta " << p->theta_ << " --- dark ---\n";
694     }
695     std::cout << "----- foreground hist ----------\n";
696     h->print();
697     std::cout << "-------------------------------------------\n";
698   }
699 }
700 
print_to_m_file_layer0(const std::string & file_name)701 void brec_part_hierarchy_learner::print_to_m_file_layer0(const std::string& file_name)
702 {
703   std::ofstream ofs(file_name.c_str());
704   ofs << "% dump histograms\n";
705 
706   for (unsigned i = 0; i < stats_layer0_.size(); i++) {
707     brec_part_instance_sptr pi = stats_layer0_[i].first;
708     bsta_histogram<float> *h = stats_layer0_[i].second;
709 
710     if (i%(n_*n_) == 0)
711       ofs << "figure;\n";
712 
713     ofs << "subplot(1," << n_*n_ << ", " << (i%(n_*n_))+1 << "), ";
714 
715     h->print_to_m(ofs);
716     if (pi->kind_ == brec_part_instance_kind::GAUSSIAN) {
717       ofs << "title('";
718       brec_part_gaussian_sptr p = pi->cast_to_gaussian();
719 
720       if (p->bright_)
721         ofs << "l0: " << p->lambda0_ << " l1: " << p->lambda1_ << " t: " << p->theta_ << " b foreg');\n";
722       else
723         ofs << "l0: " << p->lambda0_ << " l1: " << p->lambda1_ << " t: " << p->theta_ << " d foreg');\n";
724     }
725     //ofs << "axis([-7.0 1.0 0.0 1.0]);\n";
726   }
727 
728   ofs.close();
729 }
730 
print_to_m_file_layer_n(const std::string & file_name,unsigned class_id,bool print_set)731 void brec_part_hierarchy_learner::print_to_m_file_layer_n(const std::string& file_name, unsigned class_id, bool print_set)
732 {
733   std::ofstream ofs(file_name.c_str());
734   ofs << "% dump histograms\n";
735 
736   auto it = stats_layer_n_.find(class_id);
737   if (it == stats_layer_n_.end()) {
738     std::cout << "Error: Cannot find stats for class: " << class_id << '\n';
739     ofs.close();
740     return;
741   }
742 
743   class_map* map = it->second;
744 
745   class_map::iterator m_it;
746   //unsigned i = 0;
747   for (m_it = map->begin(); m_it != map->end(); m_it++)
748   {
749     std::pair<unsigned, unsigned> id_p = (*m_it).first;
750 
751     hist_ptr d_hist = (*m_it).second.second.first.first;
752     hist_ptr a_hist = (*m_it).second.second.first.second;
753 
754     //sample_set_ptr d_set = (*m_it).second.second.second.first;
755     //sample_set_ptr a_set = (*m_it).second.second.second.second;
756     sample_set_ptr set = (*m_it).second.second.second;
757 
758     brec_part_instance_sptr pa = (*m_it).second.first;
759 
760     brec_part_base_sptr cp = pa->edge_to_central_part()->target();
761     brec_part_base_sptr sp = pa->edge_to_second_part()->target();
762     std::string cp_sid = "";
763     std::string sp_sid = "";
764     brec_part_instance_sptr cpi = h_->get_node_instance(cp->layer_, cp->type_);
765     brec_part_instance_sptr spi = h_->get_node_instance(sp->layer_, sp->type_);
766     if (!cpi) {
767       //std::cout << "instance could not be found in the hierarchy!\n";
768       std::stringstream ss; ss << cp->layer_ << ' ' << cp->type_;
769       cp_sid = ss.str();
770     }
771     else if (cpi->kind_ == brec_part_instance_kind::GAUSSIAN)
772       cp_sid = cpi->cast_to_gaussian()->string_identifier();
773     if (!spi) {
774       //std::cout << "instance could not be found in the hierarchy!\n";
775       std::stringstream ss; ss << sp->layer_ << ' ' << sp->type_;
776       sp_sid = ss.str();
777     }
778     else if (spi->kind_ == brec_part_instance_kind::GAUSSIAN)
779       sp_sid = spi->cast_to_gaussian()->string_identifier();
780 
781     ofs << "h = figure;\n";
782     ofs << "set(h, 'Name', 'Class: " << class_id << " pair: (" << id_p.first << ": " << cp_sid << ", " << id_p.second << ": " << sp_sid << ") ');\n";
783 
784     //ofs << "subplot(1," << n_ << ", " << i%n_+1 << "), ";
785     //i++;
786 
787     //if (print_set)
788     //  ofs << "subplot(2,2,1), ";
789     //else
790     if (!print_set) {
791       ofs << "subplot(1,2,1), ";
792       d_hist->print_to_m(ofs);
793       ofs << "xlabel('distance in range [0,"<< radius_ << "]');\n";
794     }
795 
796     //ofs << "AXIS([0 " << d_hist->nbins()+1 << " 0.0 0.5]);\n";
797 
798     //if (print_set)
799     //  ofs << "subplot(2,2,2), ";
800 
801     if (!print_set) {
802       ofs << "subplot(1,2,2), ";
803       a_hist->print_to_m(ofs);
804       ofs << "xlabel('angle in range [0,2*pi]');\n";
805     }
806     //ofs << "AXIS([0 " << a_hist->nbins()+1 << " 0.0 0.5]);\n";
807 
808     if (print_set)
809     {
810       // create 1D marginalized distance sample set from nD set
811       bsta_sample_set<double,1> d_set;
812       if (!bsta_sample_set_marginalize(*set, 0, d_set)) {
813         std::cout << "ERROR: cannot create 1D distance set from set!\n";
814         ofs.close();
815         return;
816       }
817 
818       d_set.set_bandwidth(d_bandwidth_);
819       // run mean_shift on distance sample set
820       bsta_mean_shift<double,1> d_ms;
821       d_ms.find_modes(d_set, 0.01);
822       d_ms.trim_modes(d_set, 2*d_hist->delta());
823       d_ms.merge_modes(d_set, 3, 0.01);  // merge the modes with samples less then 3
824 
825       bsta_sample_set<double,1> a_set;
826       if (!bsta_sample_set_marginalize(*set, 1, a_set)) {
827         std::cout << "ERROR: cannot create 1D distance set from set!\n";
828         ofs.close();
829         return;
830       }
831 
832       a_set.set_bandwidth(a_bandwidth_);
833       // run mean_shift on distance sample set
834       bsta_mean_shift<double,1> a_ms;
835       a_ms.find_modes(a_set, 0.01);
836       a_ms.trim_modes(a_set, 2*a_hist->delta());
837       a_ms.merge_modes(a_set, 3, 0.01);  // merge the modes with samples less then 3
838 
839       ofs << "subplot(1,2,1), ";
840       //bsta_sample_set_print_to_m(*ss, ofs);
841       bsta_sample_set_dist_print_to_m(d_set, ofs);
842       ofs << "xlabel('distance in range [0,"<< radius_ << "]');\n";
843 
844       ofs << "subplot(1,2,2), ";
845       //bsta_sample_set_print_to_m(*ss, ofs);
846       bsta_sample_set_dist_print_to_m(a_set, ofs);
847       ofs << "xlabel('angle in range [0,2*pi]');\n";
848     }
849 
850     // ask for a character input after each figure
851     ofs << "sscanf( input('','s'), '%c' );\n";
852   }
853 
854   ofs.close();
855 }
856 
print_to_m_file_layer0_fitted_dists(const std::string & file_name)857 void brec_part_hierarchy_learner::print_to_m_file_layer0_fitted_dists(const std::string& file_name)
858 {
859   std::ofstream ofs(file_name.c_str());
860   ofs << "% dump histograms of fitted distributions\n";
861 
862   for (unsigned i = 0; i < stats_layer0_.size(); i++) {
863     brec_part_instance_sptr pi = stats_layer0_[i].first;
864     bsta_histogram<float> *h = stats_layer0_[i].second;
865 
866     std::vector<float> x;
867     for (float val = h->min(); val <= h->max(); val += h->delta()) {
868       x.push_back(val);
869     }
870 
871     if (i%(n_*n_) == 0)
872       ofs << "figure;\n";
873 
874     ofs << "subplot(1," << n_*n_ << ", " << (i%(n_*n_))+1 << "), "
875         << "x = [" << x[0];
876     for (unsigned jj = 1; jj < x.size(); jj++)
877       ofs << ", " << x[jj];
878     ofs << "];\n";
879 
880     if (pi->kind_ == brec_part_instance_kind::GAUSSIAN) {
881       brec_part_gaussian_sptr p = pi->cast_to_gaussian();
882 
883       if (p->fitted_weibull_)
884       {
885         bsta_weibull<float> pdfg(p->lambda_, p->k_);
886 
887         ofs << "y = [" << pdfg.prob_density(x[0]);
888         for (unsigned jj = 1; jj < x.size(); jj++)
889           ofs << ", " << pdfg.prob_density(x[jj]);
890         ofs << "];\n"
891             << "bar(x,y,'r');\n"
892             << "title('";
893         if (p->bright_)
894           ofs << "l0: " << p->lambda0_ << " l1: " << p->lambda1_ << " t: " << p->theta_ << " b weibull');\n";
895         else
896           ofs << "l0: " << p->lambda0_ << " l1: " << p->lambda1_ << " t: " << p->theta_ << " d weibull');\n";
897       }
898       else {
899         std::cout << "WARNING: no fitted foreground response model for this operator! Cannot print to m file.\n";
900       }
901     }
902     //ofs << "axis([-7.0 1.0 0.0 1.0]);\n";
903   }
904 
905   ofs.close();
906 }
907 
//: Binary io, NOT IMPLEMENTED, signatures defined to use brec_part_hierarchy_learner as a brdb_value
vsl_b_write(vsl_b_ostream &,brec_part_hierarchy_learner const &)909 void vsl_b_write(vsl_b_ostream & /*os*/, brec_part_hierarchy_learner const & /*ph*/)
910 {
911   std::cerr << "vsl_b_write() -- Binary io, NOT IMPLEMENTED, signatures defined to use brec_part_hierarchy_learner as a brdb_value\n";
912   return;
913 }
914 
//: Binary io, NOT IMPLEMENTED, signatures defined to use brec_part_hierarchy_learner as a brdb_value
vsl_b_read(vsl_b_istream &,brec_part_hierarchy_learner &)916 void vsl_b_read(vsl_b_istream & /*is*/, brec_part_hierarchy_learner & /*ph*/)
917 {
918   std::cerr << "vsl_b_read() -- Binary io, NOT IMPLEMENTED, signatures defined to use brec_part_hierarchy_learner as a brdb_value\n";
919   return;
920 }
921 
vsl_b_read(vsl_b_istream & is,brec_part_hierarchy_learner * ph)922 void vsl_b_read(vsl_b_istream& is, brec_part_hierarchy_learner* ph)
923 {
924   delete ph;
925   bool not_null_ptr;
926   vsl_b_read(is, not_null_ptr);
927   if (not_null_ptr)
928   {
929     ph = new brec_part_hierarchy_learner();
930     vsl_b_read(is, *ph);
931   }
932   else
933     ph = nullptr;
934 }
935 
vsl_b_write(vsl_b_ostream & os,const brec_part_hierarchy_learner * & ph)936 void vsl_b_write(vsl_b_ostream& os, const brec_part_hierarchy_learner* &ph)
937 {
938   if (ph==nullptr)
939   {
940     vsl_b_write(os, false); // Indicate null pointer stored
941   }
942   else
943   {
944     vsl_b_write(os,true); // Indicate non-null pointer stored
945     vsl_b_write(os,*ph);
946   }
947 }
948