1 // This is brl/bbas/bsta/algo/bsta_gaussian_updater.h
2 #ifndef bsta_gaussian_updater_h_
3 #define bsta_gaussian_updater_h_
4 //:
5 // \file
6 // \brief Iterative updating of Gaussians
7 // \author Matt Leotta (mleotta@lems.brown.edu)
8 // \date February 22, 2006
9 //
10 // \verbatim
11 // Modifications
12 // Jun 18, 2008 - Matt Leotta -- Adjusted such that min_var is a hard minimum
13 // instead of a minimum in the limit
14 // \endverbatim
15
16 #include <iostream>
17 #include <algorithm>
18 #include <bsta/bsta_gaussian_sphere.h>
19 #include <bsta/bsta_gaussian_indep.h>
20 #include <bsta/bsta_gaussian_full.h>
21 #include <bsta/bsta_mixture.h>
22 #include <bsta/bsta_attributes.h>
23 #ifdef _MSC_VER
24 # include <vcl_msvc_warnings.h>
25 #endif
26
27
28 //: Update the statistics given a 1D Gaussian distribution and a learning rate
29 // \note if rho = 1/(num observations) then this just an online cumulative average
30 template <class T>
bsta_update_gaussian(bsta_gaussian_sphere<T,1> & gaussian,T rho,const T & sample)31 void bsta_update_gaussian(bsta_gaussian_sphere<T,1>& gaussian, T rho,
32 const T& sample )
33 {
34 // the complement of rho (i.e. rho+rho_comp=1.0)
35 T rho_comp = 1.0f - rho;
36 // compute the updated mean
37 const T& old_mean = gaussian.mean();
38
39 T diff = sample - old_mean;
40 T new_var = rho_comp * gaussian.var();
41 new_var += (rho * rho_comp) * diff*diff;
42
43 gaussian.set_var(new_var);
44 gaussian.set_mean((old_mean) + (rho * diff));
45 }
46
47
48 //: Update the statistics given a Gaussian distribution and a learning rate
49 // \note if rho = 1/(num observations) then this just an online cumulative average
50 template <class T, unsigned n>
bsta_update_gaussian(bsta_gaussian_sphere<T,n> & gaussian,T rho,const vnl_vector_fixed<T,n> & sample)51 void bsta_update_gaussian(bsta_gaussian_sphere<T,n>& gaussian, T rho,
52 const vnl_vector_fixed<T,n>& sample )
53 {
54 // the complement of rho (i.e. rho+rho_comp=1.0)
55 T rho_comp = 1.0f - rho;
56 // compute the updated mean
57 const vnl_vector_fixed<T,n>& old_mean = gaussian.mean();
58
59 vnl_vector_fixed<T,n> diff(sample - old_mean);
60 T new_var = rho_comp * gaussian.var();
61 new_var += (rho * rho_comp) * dot_product(diff,diff);
62
63 gaussian.set_var(new_var);
64 gaussian.set_mean((old_mean) + (rho * diff));
65 }
66
67
68 //: Update the statistics given a Gaussian distribution and a learning rate
69 // \note if rho = 1/(num observations) then this just an online cumulative average
70 template <class T, unsigned n>
bsta_update_gaussian(bsta_gaussian_indep<T,n> & gaussian,T rho,const vnl_vector_fixed<T,n> & sample)71 void bsta_update_gaussian(bsta_gaussian_indep<T,n>& gaussian, T rho,
72 const vnl_vector_fixed<T,n>& sample )
73 {
74 // the complement of rho (i.e. rho+rho_comp=1.0)
75 T rho_comp = 1.0f - rho;
76 // compute the updated mean
77 const vnl_vector_fixed<T,n>& old_mean = gaussian.mean();
78
79 vnl_vector_fixed<T,n> diff(sample - old_mean);
80
81 vnl_vector_fixed<T,n> new_covar(rho_comp * gaussian.diag_covar());
82 new_covar += (rho * rho_comp) * element_product(diff,diff);
83
84 gaussian.set_covar(new_covar);
85 gaussian.set_mean((old_mean) + (rho * diff));
86 }
87
88
89 //: Update the statistics given a Gaussian distribution and a learning rate
90 // \note if rho = 1/(num observations) then this just an online cumulative average
91 template <class T, unsigned n>
bsta_update_gaussian(bsta_gaussian_full<T,n> & gaussian,T rho,const vnl_vector_fixed<T,n> & sample)92 void bsta_update_gaussian(bsta_gaussian_full<T,n>& gaussian, T rho,
93 const vnl_vector_fixed<T,n>& sample )
94 {
95 // the complement of rho (i.e. rho+rho_comp=1.0)
96 T rho_comp = 1.0f - rho;
97 // compute the updated mean
98 const vnl_vector_fixed<T,n>& old_mean = gaussian.mean();
99
100 vnl_vector_fixed<T,n> diff(sample - old_mean);
101
102 vnl_matrix_fixed<T,n,n> new_covar(rho_comp * gaussian.covar());
103 new_covar += (rho * rho_comp) * outer_product(diff,diff);
104
105 gaussian.set_covar(new_covar);
106 gaussian.set_mean((old_mean) + (rho * diff));
107 }
108
109
110 //-----------------------------------------------------------------------------
111 // The following versions allow for a lower limit on variances.
112 // If the same sample is observed repeatedly, the variances will
113 // converge to the minimum value parameter rather than zero.
114
115
//: Maximum of two scalars
// Returns \a b when a < b and \a a otherwise (so \a a is returned on ties).
template <class T>
inline T element_max(const T& a, const T& b)
{
  return (a < b) ? b : a;
}
121
122
123 //: element-wise minimum of vector.
124 template <class T, unsigned n>
element_max(const vnl_vector_fixed<T,n> & a_vector,const T & b)125 vnl_vector_fixed<T,n> element_max(const vnl_vector_fixed<T,n>& a_vector,
126 const T& b)
127 {
128 vnl_vector_fixed<T,n> min_vector;
129 T* r = min_vector.data_block();
130 const T* a = a_vector.data_block();
131 for (unsigned i=0; i<n; ++i, ++r, ++a)
132 *r = std::max(*a,b);
133 return min_vector;
134 }
135
136
137 //: element-wise minimum of vector.
138 template <class T, unsigned n>
element_max(const vnl_vector_fixed<T,n> & a_vector,const vnl_vector_fixed<T,n> & b_vector)139 vnl_vector_fixed<T,n> element_max(const vnl_vector_fixed<T,n>& a_vector,
140 const vnl_vector_fixed<T,n>& b_vector)
141 {
142 vnl_vector_fixed<T,n> min_vector;
143 T* r = min_vector.data_block();
144 const T* a = a_vector.data_block();
145 const T* b = b_vector.data_block();
146 for (unsigned i=0; i<n; ++i, ++r, ++a, ++b)
147 *r = std::max(*a,*b);
148 return min_vector;
149 }
150
151
152 //: element-wise minimum on the matrix diagonal.
153 template <class T, unsigned n>
element_max(const vnl_matrix_fixed<T,n,n> & a_matrix,const T & b)154 vnl_matrix_fixed<T,n,n> element_max(const vnl_matrix_fixed<T,n,n>& a_matrix,
155 const T& b)
156 {
157 vnl_matrix_fixed<T,n,n> min_matrix(a_matrix);
158 T* r = min_matrix.data_block();
159 const T* a = a_matrix.data_block();
160 const unsigned step = n+1;
161 for (unsigned i=0; i<n; ++i, r+=step, a+=step)
162 *r = std::max(*a,b);
163 return min_matrix;
164 }
165
166
167 //: element-wise minimum of matrix.
168 template <class T, unsigned n>
element_max(const vnl_matrix_fixed<T,n,n> & a_matrix,const vnl_matrix_fixed<T,n,n> & b_matrix)169 vnl_matrix_fixed<T,n,n> element_max(const vnl_matrix_fixed<T,n,n>& a_matrix,
170 const vnl_matrix_fixed<T,n,n>& b_matrix)
171 {
172 vnl_matrix_fixed<T,n,n> min_matrix;
173 T* r = min_matrix.data_block();
174 const T* a = a_matrix.data_block();
175 const T* b = b_matrix.data_block();
176 const unsigned num_elements = n*n;
177 for (unsigned i=0; i<num_elements; ++i, ++r, ++a, ++b)
178 *r = std::max(*a,*b);
179 return min_matrix;
180 }
181
182
//: Update a Gaussian with a learning rate, then enforce a covariance floor
// Performs the plain three-argument update and afterwards replaces the
// covariance by element_max(covariance, min_covar).
// \param gaussian the distribution to modify in place
// \param rho the learning rate, i.e. the weight given to the new sample
// \param sample the new observation
// \param min_covar forces the covariance to stay above this limit
// \note if rho = 1/(num observations) then this is just an online cumulative average
template <class gauss_>
inline void bsta_update_gaussian(gauss_& gaussian,
                                 typename gauss_::math_type rho,
                                 const typename gauss_::vector_type& sample,
                                 const typename gauss_::covar_type& min_covar)
{
  // unconstrained update first
  bsta_update_gaussian(gaussian, rho, sample);

  // then apply the floor to the freshly updated covariance
  const typename gauss_::covar_type floored =
    element_max(gaussian.covar(),min_covar);
  gaussian.set_covar(floored);
}
195
196
197 //: Update the statistics given a Gaussian distribution and a learning rate
198 // \param min_var forces all the variances to stay above this limit
199 // \note if rho = 1/(num observations) then this just an online cumulative average
200 template <class T, unsigned n>
bsta_update_gaussian(bsta_gaussian_indep<T,n> & gaussian,T rho,const vnl_vector_fixed<T,n> & sample,T min_var)201 inline void bsta_update_gaussian(bsta_gaussian_indep<T,n>& gaussian, T rho,
202 const vnl_vector_fixed<T,n>& sample,
203 T min_var)
204 {
205 bsta_update_gaussian(gaussian, rho, sample);
206 gaussian.set_covar(element_max(gaussian.covar(),min_var));
207 }
208
209
210 //: Update the statistics given a Gaussian distribution and a learning rate
211 // \param min_var forces the diagonal covariance to stay above this limit
212 // \note if rho = 1/(num observations) then this just an online cumulative average
213 template <class T, unsigned n>
bsta_update_gaussian(bsta_gaussian_full<T,n> & gaussian,T rho,const vnl_vector_fixed<T,n> & sample,T min_var)214 inline void bsta_update_gaussian(bsta_gaussian_full<T,n>& gaussian, T rho,
215 const vnl_vector_fixed<T,n>& sample,
216 T min_var)
217 {
218 bsta_update_gaussian(gaussian, rho, sample);
219 gaussian.set_covar(element_max(gaussian.covar(),min_var));
220 }
221
222
223 //-----------------------------------------------------------------------------
224
225
226 //: An updater for statistically updating Gaussian distributions
227 template <class gauss_>
228 class bsta_gaussian_updater
229 {
230 typedef bsta_num_obs<gauss_> obs_gauss_;
231 typedef typename gauss_::math_type T;
232 typedef vnl_vector_fixed<T,gauss_::dimension> vector_;
233 public:
234
235 //: for compatibility with vpdl/vpdt
236 typedef typename gauss_::field_type field_type;
237 typedef gauss_ distribution_type;
238
239
240 //: The main function
241 // make the appropriate type casts and call a helper function
operator()242 void operator() ( obs_gauss_& d, const vector_& sample ) const
243 {
244 d.num_observations += T(1);
245 bsta_update_gaussian(d, T(1)/d.num_observations, sample);
246 }
247 };
248
249
250 //: An updater for updating Gaussian distributions with a moving window
251 // When the number of samples exceeds the window size the most recent
252 // samples contribute more toward the distribution.
253 template <class gauss_>
254 class bsta_gaussian_window_updater
255 {
256 typedef bsta_num_obs<gauss_> obs_gauss_;
257 typedef typename gauss_::math_type T;
258 typedef vnl_vector_fixed<T,gauss_::dimension> vector_;
259 public:
260
261 //: for compatibility with vpdl/vpdt
262 typedef typename gauss_::field_type field_type;
263 typedef gauss_ distribution_type;
264
265
266 //: Constructor
bsta_gaussian_window_updater(unsigned int ws)267 bsta_gaussian_window_updater(unsigned int ws) : window_size(ws) {}
268
269 //: The main function
270 // make the appropriate type casts and call a helper function
operator()271 void operator() ( obs_gauss_& d, const vector_& sample) const
272 {
273 if (d.num_observations < window_size)
274 d.num_observations += T(1);
275 bsta_update_gaussian(d, T(1)/d.num_observations, sample);
276 }
277
278 unsigned int window_size;
279 };
280
281
282 #endif // bsta_gaussian_updater_h_
283