1 // Copyright (C) 2012 Davis E. King (davis@dlib.net) 2 // License: Boost Software License See LICENSE.txt for the full license. 3 4 5 #include <dlib/filtering.h> 6 #include <sstream> 7 #include <string> 8 #include <cstdlib> 9 #include <ctime> 10 #include <dlib/matrix.h> 11 #include <dlib/rand.h> 12 13 #include "tester.h" 14 15 namespace 16 { 17 18 using namespace test; 19 using namespace dlib; 20 using namespace std; 21 22 logger dlog("test.filtering"); 23 24 // ---------------------------------------------------------------------------------------- 25 26 template <typename filter_type> test_filter(filter_type kf,int size)27 double test_filter ( 28 filter_type kf, 29 int size 30 ) 31 { 32 // This test has a point moving in a circle around the origin. The point 33 // also gets a random bump in a random direction at each time step. 34 35 running_stats<double> rs; 36 37 dlib::rand rnd; 38 int count = 0; 39 const dlib::vector<double,3> z(0,0,1); 40 dlib::vector<double,2> p(10,10), temp; 41 for (int i = 0; i < size; ++i) 42 { 43 // move the point around in a circle 44 p += z.cross(p).normalize()/0.5; 45 // randomly drop measurements 46 if (rnd.get_random_double() < 0.7 || count < 4) 47 { 48 // make a random bump 49 dlib::vector<double,2> pp; 50 pp.x() = rnd.get_random_gaussian()/3; 51 pp.y() = rnd.get_random_gaussian()/3; 52 53 ++count; 54 kf.update(p+pp); 55 } 56 else 57 { 58 kf.update(); 59 dlog << LTRACE << "MISSED MEASUREMENT"; 60 } 61 // figure out the next position 62 temp = (p+z.cross(p).normalize()/0.5); 63 const double error = length(temp - rowm(kf.get_predicted_next_state(),range(0,1))); 64 rs.add(error); 65 66 dlog << LTRACE << temp << "("<< error << "): " << trans(kf.get_predicted_next_state()); 67 68 // test the serialization a few times. 69 if (count < 10) 70 { 71 ostringstream sout; 72 serialize(kf, sout); 73 istringstream sin(sout.str()); 74 filter_type temp; 75 deserialize(temp, sin); 76 kf = temp; 77 } 78 } 79 80 81 return rs.mean(); 82 83 } 84 85 // ---------------------------------------------------------------------------------------- 86 test_kalman_filter()87 void test_kalman_filter() 88 { 89 matrix<double,2,2> R; 90 R = 0.3, 0, 91 0, 0.3; 92 93 // the variables in the state are 94 // x,y, x velocity, y velocity, x acceleration, and y acceleration 95 matrix<double,6,6> A; 96 A = 1, 0, 1, 0, 0, 0, 97 0, 1, 0, 1, 0, 0, 98 0, 0, 1, 0, 1, 0, 99 0, 0, 0, 1, 0, 1, 100 0, 0, 0, 0, 1, 0, 101 0, 0, 0, 0, 0, 1; 102 103 // the measurements only tell us the positions 104 matrix<double,2,6> H; 105 H = 1, 0, 0, 0, 0, 0, 106 0, 1, 0, 0, 0, 0; 107 108 109 kalman_filter<6,2> kf; 110 kf.set_measurement_noise(R); 111 matrix<double> pn = 0.01*identity_matrix<double,6>(); 112 kf.set_process_noise(pn); 113 kf.set_observation_model(H); 114 kf.set_transition_model(A); 115 116 DLIB_TEST(equal(kf.get_observation_model() , H)); 117 DLIB_TEST(equal(kf.get_transition_model() , A)); 118 DLIB_TEST(equal(kf.get_measurement_noise() , R)); 119 DLIB_TEST(equal(kf.get_process_noise() , pn)); 120 DLIB_TEST(equal(kf.get_current_estimation_error_covariance() , identity_matrix(pn))); 121 122 double kf_error = test_filter(kf, 300); 123 124 dlog << LINFO << "kf error: "<< kf_error; 125 DLIB_TEST_MSG(kf_error < 0.75, kf_error); 126 } 127 128 // ---------------------------------------------------------------------------------------- 129 test_rls_filter()130 void test_rls_filter() 131 { 132 133 rls_filter rls(10, 0.99, 0.1); 134 135 DLIB_TEST(rls.get_window_size() == 10); 136 DLIB_TEST(rls.get_forget_factor() == 0.99); 137 DLIB_TEST(rls.get_c() == 0.1); 138 139 double rls_error = test_filter(rls, 1000); 140 141 dlog << LINFO << "rls error: "<< rls_error; 142 DLIB_TEST_MSG(rls_error < 0.75, rls_error); 143 } 144 145 // ---------------------------------------------------------------------------------------- 146 147 class filtering_tester : public tester 148 { 149 public: filtering_tester()150 filtering_tester ( 151 ) : 152 tester ("test_filtering", 153 "Runs tests on the filtering stuff (rls and kalman filters).") 154 {} 155 perform_test()156 void perform_test ( 157 ) 158 { 159 test_rls_filter(); 160 test_kalman_filter(); 161 } 162 } a; 163 164 } 165 166 167