1 // Copyright (C) 2012  Davis E. King (davis@dlib.net)
2 // License: Boost Software License   See LICENSE.txt for the full license.
3 #ifndef DLIB_KALMAN_FiLTER_Hh_
4 #define DLIB_KALMAN_FiLTER_Hh_
5 
6 #include "kalman_filter_abstract.h"
7 #include "../matrix.h"
8 #include "../geometry.h"
9 
10 namespace dlib
11 {
12 
13 // ----------------------------------------------------------------------------------------
14 
15     template <
16         long states,
17         long measurements
18         >
19     class kalman_filter
20     {
21     public:
22 
kalman_filter()23         kalman_filter()
24         {
25             H = 0;
26             A = 0;
27             Q = 0;
28             R = 0;
29             x = 0;
30             xb = 0;
31             P = identity_matrix<double>(states);
32             got_first_meas = false;
33         }
34 
set_observation_model(const matrix<double,measurements,states> & H_)35         void set_observation_model ( const matrix<double,measurements,states>& H_) { H = H_; }
set_transition_model(const matrix<double,states,states> & A_)36         void set_transition_model  ( const matrix<double,states,states>& A_) { A = A_; }
set_process_noise(const matrix<double,states,states> & Q_)37         void set_process_noise     ( const matrix<double,states,states>& Q_) { Q = Q_; }
set_measurement_noise(const matrix<double,measurements,measurements> & R_)38         void set_measurement_noise ( const matrix<double,measurements,measurements>& R_) { R = R_; }
set_estimation_error_covariance(const matrix<double,states,states> & P_)39         void set_estimation_error_covariance( const matrix<double,states,states>& P_) { P = P_; }
set_state(const matrix<double,states,1> & xb_)40         void set_state             ( const matrix<double,states,1>& xb_)
41         {
42             xb = xb_;
43             if (!got_first_meas)
44             {
45                 x = xb_;
46                 got_first_meas = true;
47             }
48         }
49 
get_observation_model()50         const matrix<double,measurements,states>& get_observation_model (
51         ) const { return H; }
52 
get_transition_model()53         const matrix<double,states,states>& get_transition_model (
54         ) const { return A; }
55 
get_process_noise()56         const matrix<double,states,states>& get_process_noise (
57         ) const { return Q; }
58 
get_measurement_noise()59         const matrix<double,measurements,measurements>& get_measurement_noise (
60         ) const { return R; }
61 
update()62         void update (
63         )
64         {
65             // propagate estimation error covariance forward
66             P = A*P*trans(A) + Q;
67 
68             // propagate state forward
69             x = xb;
70             xb = A*x;
71         }
72 
update(const matrix<double,measurements,1> & z)73         void update (const matrix<double,measurements,1>& z)
74         {
75             // propagate estimation error covariance forward
76             P = A*P*trans(A) + Q;
77 
78             // compute Kalman gain matrix
79             const matrix<double,states,measurements> K = P*trans(H)*pinv(H*P*trans(H) + R);
80 
81             if (got_first_meas)
82             {
83                 const matrix<double,measurements,1> res = z - H*xb;
84                 // correct the current state estimate
85                 x = xb + K*res;
86             }
87             else
88             {
89                 // Since we don't have a previous state estimate at the start of filtering,
90                 // we will just set the current state to whatever is indicated by the measurement
91                 x = pinv(H)*z;
92                 got_first_meas = true;
93             }
94 
95             // propagate state forward in time
96             xb = A*x;
97 
98             // update estimation error covariance since we got a measurement.
99             P = (identity_matrix<double,states>() - K*H)*P;
100         }
101 
get_current_state()102         const matrix<double,states,1>& get_current_state(
103         ) const
104         {
105             return x;
106         }
107 
get_predicted_next_state()108         const matrix<double,states,1>& get_predicted_next_state(
109         ) const
110         {
111             return xb;
112         }
113 
get_current_estimation_error_covariance()114         const matrix<double,states,states>& get_current_estimation_error_covariance(
115         ) const
116         {
117             return P;
118         }
119 
serialize(const kalman_filter & item,std::ostream & out)120         friend inline void serialize(const kalman_filter& item, std::ostream& out)
121         {
122             int version = 1;
123             serialize(version, out);
124             serialize(item.got_first_meas, out);
125             serialize(item.x, out);
126             serialize(item.xb, out);
127             serialize(item.P, out);
128             serialize(item.H, out);
129             serialize(item.A, out);
130             serialize(item.Q, out);
131             serialize(item.R, out);
132         }
133 
deserialize(kalman_filter & item,std::istream & in)134         friend inline void deserialize(kalman_filter& item, std::istream& in)
135         {
136             int version = 0;
137             deserialize(version, in);
138             if (version != 1)
139                 throw dlib::serialization_error("Unknown version number found while deserializing kalman_filter object.");
140 
141             deserialize(item.got_first_meas, in);
142             deserialize(item.x, in);
143             deserialize(item.xb, in);
144             deserialize(item.P, in);
145             deserialize(item.H, in);
146             deserialize(item.A, in);
147             deserialize(item.Q, in);
148             deserialize(item.R, in);
149         }
150 
151     private:
152 
153         bool got_first_meas;
154         matrix<double,states,1> x, xb;
155         matrix<double,states,states> P;
156 
157         matrix<double,measurements,states> H;
158         matrix<double,states,states> A;
159         matrix<double,states,states> Q;
160         matrix<double,measurements,measurements> R;
161 
162 
163     };
164 
165 // ----------------------------------------------------------------------------------------
166 
167     class momentum_filter
168     {
169     public:
170 
momentum_filter(double meas_noise,double acc,double max_meas_dev)171         momentum_filter(
172             double meas_noise,
173             double acc,
174             double max_meas_dev
175         ) :
176             measurement_noise(meas_noise),
177             typical_acceleration(acc),
178             max_measurement_deviation(max_meas_dev)
179         {
180             DLIB_CASSERT(meas_noise >= 0);
181             DLIB_CASSERT(acc >= 0);
182             DLIB_CASSERT(max_meas_dev >= 0);
183 
184             kal.set_observation_model({1, 0});
185             kal.set_transition_model( {1, 1,
186                 0, 1});
187             kal.set_process_noise({0, 0,
188                 0, typical_acceleration*typical_acceleration});
189 
190             kal.set_measurement_noise({measurement_noise*measurement_noise});
191         }
192 
193         momentum_filter() = default;
194 
get_measurement_noise()195         double get_measurement_noise (
196         ) const { return measurement_noise; }
197 
get_typical_acceleration()198         double get_typical_acceleration (
199         ) const { return typical_acceleration; }
200 
get_max_measurement_deviation()201         double get_max_measurement_deviation (
202         ) const { return max_measurement_deviation; }
203 
reset()204         void reset()
205         {
206             *this = momentum_filter(measurement_noise, typical_acceleration, max_measurement_deviation);
207         }
208 
get_predicted_next_position()209         double get_predicted_next_position(
210         ) const
211         {
212             return kal.get_predicted_next_state()(0);
213         }
214 
operator()215         double operator()(
216             const double measured_position
217         )
218         {
219             auto x = kal.get_predicted_next_state();
220             const auto max_deviation = max_measurement_deviation*measurement_noise;
221             // Check if measured_position has suddenly jumped in value by a whole lot. This
222             // could happen if the velocity term experiences a much larger than normal
223             // acceleration, e.g.  because the underlying object is doing a maneuver.  If
224             // this happens then we clamp the state so that the predicted next value is no
225             // more than max_deviation away from measured_position at all times.
226             if (x(0) > measured_position + max_deviation)
227             {
228                 x(0) = measured_position + max_deviation;
229                 kal.set_state(x);
230             }
231             else if (x(0) < measured_position - max_deviation)
232             {
233                 x(0) = measured_position - max_deviation;
234                 kal.set_state(x);
235             }
236 
237             kal.update({measured_position});
238 
239             return kal.get_current_state()(0);
240         }
241 
242         friend std::ostream& operator << (std::ostream& out, const momentum_filter& item)
243         {
244             out << "measurement_noise:         " << item.measurement_noise << "\n";
245             out << "typical_acceleration:      " << item.typical_acceleration << "\n";
246             out << "max_measurement_deviation: " << item.max_measurement_deviation;
247             return out;
248         }
249 
serialize(const momentum_filter & item,std::ostream & out)250         friend void serialize(const momentum_filter& item, std::ostream& out)
251         {
252             int version = 15;
253             serialize(version, out);
254             serialize(item.measurement_noise, out);
255             serialize(item.typical_acceleration, out);
256             serialize(item.max_measurement_deviation, out);
257             serialize(item.kal, out);
258         }
259 
deserialize(momentum_filter & item,std::istream & in)260         friend void deserialize(momentum_filter& item, std::istream& in)
261         {
262             int version = 0;
263             deserialize(version, in);
264             if (version != 15)
265                 throw serialization_error("Unexpected version found while deserializing momentum_filter.");
266             deserialize(item.measurement_noise, in);
267             deserialize(item.typical_acceleration, in);
268             deserialize(item.max_measurement_deviation, in);
269             deserialize(item.kal, in);
270         }
271 
272     private:
273 
274         double measurement_noise = 2;
275         double typical_acceleration = 0.1;
276         double max_measurement_deviation = 3; // nominally number of standard deviations
277 
278         kalman_filter<2,1> kal;
279     };
280 
281 // ----------------------------------------------------------------------------------------
282 
    // Declared here; the definition is not part of this header.
    // Returns a momentum_filter, presumably with parameters fitted to the example
    // signals in `sequences` (one std::vector<double> per signal).  `smoothness`
    // defaults to 1.
    // NOTE(review): the fitting criterion and the exact effect of `smoothness` are not
    // visible here; see the contract in kalman_filter_abstract.h.
    momentum_filter find_optimal_momentum_filter (
        const std::vector<std::vector<double>>& sequences,
        const double smoothness = 1
    );
287 
288 // ----------------------------------------------------------------------------------------
289 
    // Single-sequence overload of find_optimal_momentum_filter(); declared here and
    // defined elsewhere.  `smoothness` defaults to 1.
    // NOTE(review): presumably equivalent to calling the multi-sequence overload with
    // one sequence; confirm against the definition / kalman_filter_abstract.h.
    momentum_filter find_optimal_momentum_filter (
        const std::vector<double>& sequence,
        const double smoothness = 1
    );
294 
295 // ----------------------------------------------------------------------------------------
296 
297     class rect_filter
298     {
299     public:
300         rect_filter() = default;
301 
rect_filter(double meas_noise,double acc,double max_meas_dev)302         rect_filter(
303             double meas_noise,
304             double acc,
305             double max_meas_dev
306         ) : rect_filter(momentum_filter(meas_noise, acc, max_meas_dev)) {}
307 
rect_filter(const momentum_filter & filt)308         rect_filter(
309             const momentum_filter& filt
310         ) :
311             left(filt),
312             top(filt),
313             right(filt),
314             bottom(filt)
315         {
316         }
317 
operator()318         drectangle operator()(const drectangle& r)
319         {
320             return drectangle(left(r.left()),
321                             top(r.top()),
322                             right(r.right()),
323                             bottom(r.bottom()));
324         }
325 
operator()326         drectangle operator()(const rectangle& r)
327         {
328             return drectangle(left(r.left()),
329                             top(r.top()),
330                             right(r.right()),
331                             bottom(r.bottom()));
332         }
333 
get_left()334         const momentum_filter& get_left   () const { return left; }
get_left()335         momentum_filter&       get_left   ()       { return left; }
get_top()336         const momentum_filter& get_top    () const { return top; }
get_top()337         momentum_filter&       get_top    ()       { return top; }
get_right()338         const momentum_filter& get_right  () const { return right; }
get_right()339         momentum_filter&       get_right  ()       { return right; }
get_bottom()340         const momentum_filter& get_bottom () const { return bottom; }
get_bottom()341         momentum_filter&       get_bottom ()       { return bottom; }
342 
serialize(const rect_filter & item,std::ostream & out)343         friend void serialize(const rect_filter& item, std::ostream& out)
344         {
345             int version = 123;
346             serialize(version, out);
347             serialize(item.left, out);
348             serialize(item.top, out);
349             serialize(item.right, out);
350             serialize(item.bottom, out);
351         }
352 
deserialize(rect_filter & item,std::istream & in)353         friend void deserialize(rect_filter& item, std::istream& in)
354         {
355             int version = 0;
356             deserialize(version, in);
357             if (version != 123)
358                 throw dlib::serialization_error("Unknown version number found while deserializing rect_filter object.");
359             deserialize(item.left, in);
360             deserialize(item.top, in);
361             deserialize(item.right, in);
362             deserialize(item.bottom, in);
363         }
364 
365     private:
366 
367         momentum_filter left, top, right, bottom;
368     };
369 
370 // ----------------------------------------------------------------------------------------
371 
    // Declared here; the definition is not part of this header.
    // Returns a rect_filter, presumably with its edge filters fitted to the example
    // track in `rects`.  `smoothness` defaults to 1.
    // NOTE(review): how the four edge filters are fitted and what `smoothness` controls
    // are not visible here; see kalman_filter_abstract.h.
    rect_filter find_optimal_rect_filter (
        const std::vector<rectangle>& rects,
        const double smoothness = 1
    );
376 
377 // ----------------------------------------------------------------------------------------
378 
379 }
380 
381 #endif // DLIB_KALMAN_FiLTER_Hh_
382 
383