1 // Copyright (C) 2012 Davis E. King (davis@dlib.net) 2 // License: Boost Software License See LICENSE.txt for the full license. 3 #ifndef DLIB_KALMAN_FiLTER_Hh_ 4 #define DLIB_KALMAN_FiLTER_Hh_ 5 6 #include "kalman_filter_abstract.h" 7 #include "../matrix.h" 8 #include "../geometry.h" 9 10 namespace dlib 11 { 12 13 // ---------------------------------------------------------------------------------------- 14 15 template < 16 long states, 17 long measurements 18 > 19 class kalman_filter 20 { 21 public: 22 kalman_filter()23 kalman_filter() 24 { 25 H = 0; 26 A = 0; 27 Q = 0; 28 R = 0; 29 x = 0; 30 xb = 0; 31 P = identity_matrix<double>(states); 32 got_first_meas = false; 33 } 34 set_observation_model(const matrix<double,measurements,states> & H_)35 void set_observation_model ( const matrix<double,measurements,states>& H_) { H = H_; } set_transition_model(const matrix<double,states,states> & A_)36 void set_transition_model ( const matrix<double,states,states>& A_) { A = A_; } set_process_noise(const matrix<double,states,states> & Q_)37 void set_process_noise ( const matrix<double,states,states>& Q_) { Q = Q_; } set_measurement_noise(const matrix<double,measurements,measurements> & R_)38 void set_measurement_noise ( const matrix<double,measurements,measurements>& R_) { R = R_; } set_estimation_error_covariance(const matrix<double,states,states> & P_)39 void set_estimation_error_covariance( const matrix<double,states,states>& P_) { P = P_; } set_state(const matrix<double,states,1> & xb_)40 void set_state ( const matrix<double,states,1>& xb_) 41 { 42 xb = xb_; 43 if (!got_first_meas) 44 { 45 x = xb_; 46 got_first_meas = true; 47 } 48 } 49 get_observation_model()50 const matrix<double,measurements,states>& get_observation_model ( 51 ) const { return H; } 52 get_transition_model()53 const matrix<double,states,states>& get_transition_model ( 54 ) const { return A; } 55 get_process_noise()56 const matrix<double,states,states>& get_process_noise ( 57 ) const { return Q; } 58 get_measurement_noise()59 const matrix<double,measurements,measurements>& get_measurement_noise ( 60 ) const { return R; } 61 update()62 void update ( 63 ) 64 { 65 // propagate estimation error covariance forward 66 P = A*P*trans(A) + Q; 67 68 // propagate state forward 69 x = xb; 70 xb = A*x; 71 } 72 update(const matrix<double,measurements,1> & z)73 void update (const matrix<double,measurements,1>& z) 74 { 75 // propagate estimation error covariance forward 76 P = A*P*trans(A) + Q; 77 78 // compute Kalman gain matrix 79 const matrix<double,states,measurements> K = P*trans(H)*pinv(H*P*trans(H) + R); 80 81 if (got_first_meas) 82 { 83 const matrix<double,measurements,1> res = z - H*xb; 84 // correct the current state estimate 85 x = xb + K*res; 86 } 87 else 88 { 89 // Since we don't have a previous state estimate at the start of filtering, 90 // we will just set the current state to whatever is indicated by the measurement 91 x = pinv(H)*z; 92 got_first_meas = true; 93 } 94 95 // propagate state forward in time 96 xb = A*x; 97 98 // update estimation error covariance since we got a measurement. 99 P = (identity_matrix<double,states>() - K*H)*P; 100 } 101 get_current_state()102 const matrix<double,states,1>& get_current_state( 103 ) const 104 { 105 return x; 106 } 107 get_predicted_next_state()108 const matrix<double,states,1>& get_predicted_next_state( 109 ) const 110 { 111 return xb; 112 } 113 get_current_estimation_error_covariance()114 const matrix<double,states,states>& get_current_estimation_error_covariance( 115 ) const 116 { 117 return P; 118 } 119 serialize(const kalman_filter & item,std::ostream & out)120 friend inline void serialize(const kalman_filter& item, std::ostream& out) 121 { 122 int version = 1; 123 serialize(version, out); 124 serialize(item.got_first_meas, out); 125 serialize(item.x, out); 126 serialize(item.xb, out); 127 serialize(item.P, out); 128 serialize(item.H, out); 129 serialize(item.A, out); 130 serialize(item.Q, out); 131 serialize(item.R, out); 132 } 133 deserialize(kalman_filter & item,std::istream & in)134 friend inline void deserialize(kalman_filter& item, std::istream& in) 135 { 136 int version = 0; 137 deserialize(version, in); 138 if (version != 1) 139 throw dlib::serialization_error("Unknown version number found while deserializing kalman_filter object."); 140 141 deserialize(item.got_first_meas, in); 142 deserialize(item.x, in); 143 deserialize(item.xb, in); 144 deserialize(item.P, in); 145 deserialize(item.H, in); 146 deserialize(item.A, in); 147 deserialize(item.Q, in); 148 deserialize(item.R, in); 149 } 150 151 private: 152 153 bool got_first_meas; 154 matrix<double,states,1> x, xb; 155 matrix<double,states,states> P; 156 157 matrix<double,measurements,states> H; 158 matrix<double,states,states> A; 159 matrix<double,states,states> Q; 160 matrix<double,measurements,measurements> R; 161 162 163 }; 164 165 // ---------------------------------------------------------------------------------------- 166 167 class momentum_filter 168 { 169 public: 170 momentum_filter(double meas_noise,double acc,double max_meas_dev)171 momentum_filter( 172 double meas_noise, 173 double acc, 174 double max_meas_dev 175 ) : 176 measurement_noise(meas_noise), 177 typical_acceleration(acc), 178 max_measurement_deviation(max_meas_dev) 179 { 180 DLIB_CASSERT(meas_noise >= 0); 181 DLIB_CASSERT(acc >= 0); 182 DLIB_CASSERT(max_meas_dev >= 0); 183 184 kal.set_observation_model({1, 0}); 185 kal.set_transition_model( {1, 1, 186 0, 1}); 187 kal.set_process_noise({0, 0, 188 0, typical_acceleration*typical_acceleration}); 189 190 kal.set_measurement_noise({measurement_noise*measurement_noise}); 191 } 192 193 momentum_filter() = default; 194 get_measurement_noise()195 double get_measurement_noise ( 196 ) const { return measurement_noise; } 197 get_typical_acceleration()198 double get_typical_acceleration ( 199 ) const { return typical_acceleration; } 200 get_max_measurement_deviation()201 double get_max_measurement_deviation ( 202 ) const { return max_measurement_deviation; } 203 reset()204 void reset() 205 { 206 *this = momentum_filter(measurement_noise, typical_acceleration, max_measurement_deviation); 207 } 208 get_predicted_next_position()209 double get_predicted_next_position( 210 ) const 211 { 212 return kal.get_predicted_next_state()(0); 213 } 214 operator()215 double operator()( 216 const double measured_position 217 ) 218 { 219 auto x = kal.get_predicted_next_state(); 220 const auto max_deviation = max_measurement_deviation*measurement_noise; 221 // Check if measured_position has suddenly jumped in value by a whole lot. This 222 // could happen if the velocity term experiences a much larger than normal 223 // acceleration, e.g. because the underlying object is doing a maneuver. If 224 // this happens then we clamp the state so that the predicted next value is no 225 // more than max_deviation away from measured_position at all times. 226 if (x(0) > measured_position + max_deviation) 227 { 228 x(0) = measured_position + max_deviation; 229 kal.set_state(x); 230 } 231 else if (x(0) < measured_position - max_deviation) 232 { 233 x(0) = measured_position - max_deviation; 234 kal.set_state(x); 235 } 236 237 kal.update({measured_position}); 238 239 return kal.get_current_state()(0); 240 } 241 242 friend std::ostream& operator << (std::ostream& out, const momentum_filter& item) 243 { 244 out << "measurement_noise: " << item.measurement_noise << "\n"; 245 out << "typical_acceleration: " << item.typical_acceleration << "\n"; 246 out << "max_measurement_deviation: " << item.max_measurement_deviation; 247 return out; 248 } 249 serialize(const momentum_filter & item,std::ostream & out)250 friend void serialize(const momentum_filter& item, std::ostream& out) 251 { 252 int version = 15; 253 serialize(version, out); 254 serialize(item.measurement_noise, out); 255 serialize(item.typical_acceleration, out); 256 serialize(item.max_measurement_deviation, out); 257 serialize(item.kal, out); 258 } 259 deserialize(momentum_filter & item,std::istream & in)260 friend void deserialize(momentum_filter& item, std::istream& in) 261 { 262 int version = 0; 263 deserialize(version, in); 264 if (version != 15) 265 throw serialization_error("Unexpected version found while deserializing momentum_filter."); 266 deserialize(item.measurement_noise, in); 267 deserialize(item.typical_acceleration, in); 268 deserialize(item.max_measurement_deviation, in); 269 deserialize(item.kal, in); 270 } 271 272 private: 273 274 double measurement_noise = 2; 275 double typical_acceleration = 0.1; 276 double max_measurement_deviation = 3; // nominally number of standard deviations 277 278 kalman_filter<2,1> kal; 279 }; 280 281 // ---------------------------------------------------------------------------------------- 282 283 momentum_filter find_optimal_momentum_filter ( 284 const std::vector<std::vector<double>>& sequences, 285 const double smoothness = 1 286 ); 287 288 // ---------------------------------------------------------------------------------------- 289 290 momentum_filter find_optimal_momentum_filter ( 291 const std::vector<double>& sequence, 292 const double smoothness = 1 293 ); 294 295 // ---------------------------------------------------------------------------------------- 296 297 class rect_filter 298 { 299 public: 300 rect_filter() = default; 301 rect_filter(double meas_noise,double acc,double max_meas_dev)302 rect_filter( 303 double meas_noise, 304 double acc, 305 double max_meas_dev 306 ) : rect_filter(momentum_filter(meas_noise, acc, max_meas_dev)) {} 307 rect_filter(const momentum_filter & filt)308 rect_filter( 309 const momentum_filter& filt 310 ) : 311 left(filt), 312 top(filt), 313 right(filt), 314 bottom(filt) 315 { 316 } 317 operator()318 drectangle operator()(const drectangle& r) 319 { 320 return drectangle(left(r.left()), 321 top(r.top()), 322 right(r.right()), 323 bottom(r.bottom())); 324 } 325 operator()326 drectangle operator()(const rectangle& r) 327 { 328 return drectangle(left(r.left()), 329 top(r.top()), 330 right(r.right()), 331 bottom(r.bottom())); 332 } 333 get_left()334 const momentum_filter& get_left () const { return left; } get_left()335 momentum_filter& get_left () { return left; } get_top()336 const momentum_filter& get_top () const { return top; } get_top()337 momentum_filter& get_top () { return top; } get_right()338 const momentum_filter& get_right () const { return right; } get_right()339 momentum_filter& get_right () { return right; } get_bottom()340 const momentum_filter& get_bottom () const { return bottom; } get_bottom()341 momentum_filter& get_bottom () { return bottom; } 342 serialize(const rect_filter & item,std::ostream & out)343 friend void serialize(const rect_filter& item, std::ostream& out) 344 { 345 int version = 123; 346 serialize(version, out); 347 serialize(item.left, out); 348 serialize(item.top, out); 349 serialize(item.right, out); 350 serialize(item.bottom, out); 351 } 352 deserialize(rect_filter & item,std::istream & in)353 friend void deserialize(rect_filter& item, std::istream& in) 354 { 355 int version = 0; 356 deserialize(version, in); 357 if (version != 123) 358 throw dlib::serialization_error("Unknown version number found while deserializing rect_filter object."); 359 deserialize(item.left, in); 360 deserialize(item.top, in); 361 deserialize(item.right, in); 362 deserialize(item.bottom, in); 363 } 364 365 private: 366 367 momentum_filter left, top, right, bottom; 368 }; 369 370 // ---------------------------------------------------------------------------------------- 371 372 rect_filter find_optimal_rect_filter ( 373 const std::vector<rectangle>& rects, 374 const double smoothness = 1 375 ); 376 377 // ---------------------------------------------------------------------------------------- 378 379 } 380 381 #endif // DLIB_KALMAN_FiLTER_Hh_ 382 383