1 // Copyright (C) 2012  Davis E. King (davis@dlib.net)
2 // License: Boost Software License   See LICENSE.txt for the full license.
3 
4 
5 #include <dlib/filtering.h>
6 #include <sstream>
7 #include <string>
8 #include <cstdlib>
9 #include <ctime>
10 #include <dlib/matrix.h>
11 #include <dlib/rand.h>
12 
13 #include "tester.h"
14 
15 namespace
16 {
17 
18     using namespace test;
19     using namespace dlib;
20     using namespace std;
21 
    // Logger shared by the tests in this file (used for LTRACE per-step output
    // in test_filter and LINFO summaries in the individual filter tests).
    logger dlog("test.filtering");
23 
24 // ----------------------------------------------------------------------------------------
25 
26     template <typename filter_type>
test_filter(filter_type kf,int size)27     double test_filter (
28         filter_type kf,
29         int size
30     )
31     {
32         // This test has a point moving in a circle around the origin.  The point
33         // also gets a random bump in a random direction at each time step.
34 
35         running_stats<double> rs;
36 
37         dlib::rand rnd;
38         int count = 0;
39         const dlib::vector<double,3> z(0,0,1);
40         dlib::vector<double,2> p(10,10), temp;
41         for (int i = 0; i < size; ++i)
42         {
43             // move the point around in a circle
44             p += z.cross(p).normalize()/0.5;
45             // randomly drop measurements
46             if (rnd.get_random_double() < 0.7 || count < 4)
47             {
48                 // make a random bump
49                 dlib::vector<double,2> pp;
50                 pp.x() = rnd.get_random_gaussian()/3;
51                 pp.y() = rnd.get_random_gaussian()/3;
52 
53                 ++count;
54                 kf.update(p+pp);
55             }
56             else
57             {
58                 kf.update();
59                 dlog << LTRACE << "MISSED MEASUREMENT";
60             }
61             // figure out the next position
62             temp = (p+z.cross(p).normalize()/0.5);
63             const double error = length(temp - rowm(kf.get_predicted_next_state(),range(0,1)));
64             rs.add(error);
65 
66             dlog << LTRACE << temp << "("<< error << "): " << trans(kf.get_predicted_next_state());
67 
68             // test the serialization a few times.
69             if (count < 10)
70             {
71                 ostringstream sout;
72                 serialize(kf, sout);
73                 istringstream sin(sout.str());
74                 filter_type temp;
75                 deserialize(temp, sin);
76                 kf = temp;
77             }
78         }
79 
80 
81         return rs.mean();
82 
83     }
84 
85 // ----------------------------------------------------------------------------------------
86 
test_kalman_filter()87     void test_kalman_filter()
88     {
89         matrix<double,2,2> R;
90         R = 0.3, 0,
91         0,  0.3;
92 
93         // the variables in the state are
94         // x,y, x velocity, y velocity, x acceleration, and y acceleration
95         matrix<double,6,6> A;
96         A = 1, 0, 1, 0, 0, 0,
97         0, 1, 0, 1, 0, 0,
98         0, 0, 1, 0, 1, 0,
99         0, 0, 0, 1, 0, 1,
100         0, 0, 0, 0, 1, 0,
101         0, 0, 0, 0, 0, 1;
102 
103         // the measurements only tell us the positions
104         matrix<double,2,6> H;
105         H = 1, 0, 0, 0, 0, 0,
106         0, 1, 0, 0, 0, 0;
107 
108 
109         kalman_filter<6,2> kf;
110         kf.set_measurement_noise(R);
111         matrix<double> pn = 0.01*identity_matrix<double,6>();
112         kf.set_process_noise(pn);
113         kf.set_observation_model(H);
114         kf.set_transition_model(A);
115 
116         DLIB_TEST(equal(kf.get_observation_model() , H));
117         DLIB_TEST(equal(kf.get_transition_model() , A));
118         DLIB_TEST(equal(kf.get_measurement_noise() , R));
119         DLIB_TEST(equal(kf.get_process_noise() , pn));
120         DLIB_TEST(equal(kf.get_current_estimation_error_covariance() , identity_matrix(pn)));
121 
122         double kf_error = test_filter(kf, 300);
123 
124         dlog << LINFO << "kf error: "<< kf_error;
125         DLIB_TEST_MSG(kf_error < 0.75, kf_error);
126     }
127 
128 // ----------------------------------------------------------------------------------------
129 
test_rls_filter()130     void test_rls_filter()
131     {
132 
133         rls_filter rls(10, 0.99, 0.1);
134 
135         DLIB_TEST(rls.get_window_size() == 10);
136         DLIB_TEST(rls.get_forget_factor() == 0.99);
137         DLIB_TEST(rls.get_c() == 0.1);
138 
139         double rls_error = test_filter(rls, 1000);
140 
141         dlog << LINFO << "rls error: "<< rls_error;
142         DLIB_TEST_MSG(rls_error < 0.75, rls_error);
143     }
144 
145 // ----------------------------------------------------------------------------------------
146 
147     class filtering_tester : public tester
148     {
149     public:
filtering_tester()150         filtering_tester (
151         ) :
152             tester ("test_filtering",
153                     "Runs tests on the filtering stuff (rls and kalman filters).")
154         {}
155 
perform_test()156         void perform_test (
157         )
158         {
159             test_rls_filter();
160             test_kalman_filter();
161         }
162     } a;
163 
164 }
165 
166 
167