1 #ifndef sdet_mrf_site_bp_h_
2 #define sdet_mrf_site_bp_h_
3 //:
4 // \file
5 // \brief  A class for representing a site node in an MRF
6 // \author J.L. Mundy
7 // \date   26 March 2011
8 //
// Stores two sets of buffer arrays: one for messages received at the last
// step and one for messages coming in during the current iteration.
// The MRF has a 4-neighborhood (u, l, r, d) with the neighbor indices
// in the order (0, 1, 2, 3).
13 //
// The data cost for each site is D(fp) = lambda_*(fp-x)^2, where x is the
// observed data and fp is a site label.
//
// The cost is truncated: D(fp) = min(lambda_*(fp-x)^2, truncation_cost_).
18 //
//  Each site stores a pair of label buffers for each of its neighbors.
//  One buffer in the pair stores the message received on the last
//  iteration (p); the other receives the current messages (c). On each
//  iteration, the buffers are swapped. A pair of buffers is allocated for
//  each of the (u, l, r, d) neighbors in the 4-connected neighborhood,
//  as shown below.
//
//          (c,p)
//            u
//    (c,p) l x r (c,p)
//            d
//          (c,p)
31 //
32 #include <iostream>
33 #include <vector>
34 #include <vbl/vbl_ref_count.h>
35 #ifdef _MSC_VER
36 #  include <vcl_msvc_warnings.h>
37 #endif
38 class sdet_mrf_site_bp : public vbl_ref_count
39 {
40  public:
41 
42   sdet_mrf_site_bp(unsigned n_labels, float lambda, float truncation_cost);
switch_buffers()43   void switch_buffers() { prior_ = 1-prior_; }
prior()44   int prior() const { return prior_; }
current()45   int current() const { return 1-prior_; }
46   //: set the observed label
set_label(float obs_label)47   void set_label(float obs_label) { obs_label_ = obs_label; }
48 
49   // === cost functions ===
50 
51   //:data cost due to observed continuous label value
52   float D(unsigned fp);
53 
54   //:sum over stored prior messages, except the message from neighbor nq
55   float M(unsigned nq, unsigned fp);
56 
57   //:total of D and M
h(unsigned nq,unsigned fp)58   float h(unsigned nq, unsigned fp) { return D(fp) + M(nq, fp); }
59 
60   //:belief, sum of data cost and sum of all four prior messages
61   float b(unsigned fp);
62 
63   //:the most probable label, label with minimum belief
64   unsigned believed_label();
65 
66   //set the current message from neighbor nq
67   void set_cur_message(unsigned nq, unsigned fp, float msg);
68 
69   //:the current message value
cur_message(unsigned nq,unsigned fp)70   float cur_message(unsigned nq, unsigned fp) const { return msg_[1-prior_][nq][fp]; }
71 
72   //:the prior message value
prior_message(unsigned nq,unsigned fp)73   float prior_message(unsigned nq, unsigned fp) const { return msg_[prior_][nq][fp]; }
74 
75   //:entire prior message
76   std::vector<float> prior_message(unsigned nq);
77 
78   //:set prior message
79   void set_prior_message(unsigned nq, std::vector<float>const& msg);
80 
81   //:clear messages
82   void clear();
83 
84   //: print the value of the messages held in the prior queue.
85   void print_prior_messages();
86   void print_current_messages();
87   void print_belief_vector();
88 
89  protected:
90   // parameters for computing message values
91   float lambda_;
92   float truncation_cost_;
93 
94   //: this index is toggled to swap the buffers
95   int prior_;
96 
97   //: typically 256
98   unsigned n_labels_;
99 
100   //: currently 4, but might change in the future
101   unsigned n_ngbh_;
102 
103   // define the neighbor index
104   //     0
105   //   1 x 2
106   //     3
107 
108   //: a set of 2 message buffers, prior and current, one for each neighbor
109   // (p, c)     n_ngbh_    n_labels_
110   std::vector< std::vector<std::vector<short> > > msg_;
111   // cut down storage using short (for byte images should be adequate)
112 
113   //: the label represented by the data
114   float obs_label_;
115 };
116 
117 #include <sdet/sdet_mrf_site_bp_sptr.h>
118 #endif // sdet_mrf_site_bp_h_
119