1 #ifndef sdet_mrf_site_bp_h_ 2 #define sdet_mrf_site_bp_h_ 3 //: 4 // \file 5 // \brief A class for representing a site node in an MRF 6 // \author J.L. Mundy 7 // \date 26 March 2011 8 // 9 // Stores two sets of buffer arrays,one for messages received at the last step 10 // and one for messages coming in during the current iteration 11 // the MRF has a 4-neighborhood (u, l, r, d) with the neighbor index 12 // in the order (0 1 2 3). 13 // 14 // The data cost for each site is D(fp) = lambda_*(fp-x)^2, where x is the 15 // observed data and fp is a site label. 16 // 17 // D(fp) is set to min(lambda_*(fp-x)^2, truncation_cost_); 18 // 19 // Each site stores a pair of label buffers for each of the neighbors 20 // One buffer in the pair stores the message received on the last 21 // iteration (p), the other receives the current messages (c). On each 22 // iteration, the buffers are swapped. A pair of buffers is allocated for 23 // each of the (u, l, r, d) neighbors in the 4-connected neighborhood 24 // as shown below. 25 // 26 // (c p) 27 // u 28 // (c,p) l x r (c,p) 29 // d 30 // (c,p) 31 // 32 #include <iostream> 33 #include <vector> 34 #include <vbl/vbl_ref_count.h> 35 #ifdef _MSC_VER 36 # include <vcl_msvc_warnings.h> 37 #endif 38 class sdet_mrf_site_bp : public vbl_ref_count 39 { 40 public: 41 42 sdet_mrf_site_bp(unsigned n_labels, float lambda, float truncation_cost); switch_buffers()43 void switch_buffers() { prior_ = 1-prior_; } prior()44 int prior() const { return prior_; } current()45 int current() const { return 1-prior_; } 46 //: set the observed label set_label(float obs_label)47 void set_label(float obs_label) { obs_label_ = obs_label; } 48 49 // === cost functions === 50 51 //:data cost due to observed continuous label value 52 float D(unsigned fp); 53 54 //:sum over stored prior messages, except the message from neighbor nq 55 float M(unsigned nq, unsigned fp); 56 57 //:total of D and M h(unsigned nq,unsigned fp)58 float h(unsigned nq, unsigned fp) { return D(fp) + M(nq, fp); } 59 60 //:belief, sum of data cost and sum of all four prior messages 61 float b(unsigned fp); 62 63 //:the most probable label, label with minimum belief 64 unsigned believed_label(); 65 66 //set the current message from neighbor nq 67 void set_cur_message(unsigned nq, unsigned fp, float msg); 68 69 //:the current message value cur_message(unsigned nq,unsigned fp)70 float cur_message(unsigned nq, unsigned fp) const { return msg_[1-prior_][nq][fp]; } 71 72 //:the prior message value prior_message(unsigned nq,unsigned fp)73 float prior_message(unsigned nq, unsigned fp) const { return msg_[prior_][nq][fp]; } 74 75 //:entire prior message 76 std::vector<float> prior_message(unsigned nq); 77 78 //:set prior message 79 void set_prior_message(unsigned nq, std::vector<float>const& msg); 80 81 //:clear messages 82 void clear(); 83 84 //: print the value of the messages held in the prior queue. 85 void print_prior_messages(); 86 void print_current_messages(); 87 void print_belief_vector(); 88 89 protected: 90 // parameters for computing message values 91 float lambda_; 92 float truncation_cost_; 93 94 //: this index is toggled to swap the buffers 95 int prior_; 96 97 //: typically 256 98 unsigned n_labels_; 99 100 //: currently 4, but might change in the future 101 unsigned n_ngbh_; 102 103 // define the neighbor index 104 // 0 105 // 1 x 2 106 // 3 107 108 //: a set of 2 message buffers, prior and current, one for each neighbor 109 // (p, c) n_ngbh_ n_labels_ 110 std::vector< std::vector<std::vector<short> > > msg_; 111 // cut down storage using short (for byte images should be adequate) 112 113 //: the label represented by the data 114 float obs_label_; 115 }; 116 117 #include <sdet/sdet_mrf_site_bp_sptr.h> 118 #endif // sdet_mrf_site_bp_h_ 119