1 // Copyright (C) 2010 Davis E. King (davis@dlib.net) 2 // License: Boost Software License See LICENSE.txt for the full license. 3 4 #include "tester.h" 5 #include <dlib/svm_threaded.h> 6 #include <dlib/statistics.h> 7 #include <vector> 8 #include <sstream> 9 10 namespace 11 { 12 using namespace test; 13 using namespace dlib; 14 using namespace std; 15 dlib::logger dlog("test.one_vs_one_trainer"); 16 17 18 class test_one_vs_one_trainer : public tester 19 { 20 /*! 21 WHAT THIS OBJECT REPRESENTS 22 This object represents a unit test. When it is constructed 23 it adds itself into the testing framework. 24 !*/ 25 public: test_one_vs_one_trainer()26 test_one_vs_one_trainer ( 27 ) : 28 tester ( 29 "test_one_vs_one_trainer", // the command line argument name for this test 30 "Run tests on the one_vs_one_trainer stuff.", // the command line argument description 31 0 // the number of command line arguments for this test 32 ) 33 { 34 } 35 36 37 38 template <typename sample_type, typename label_type> generate_data(std::vector<sample_type> & samples,std::vector<label_type> & labels)39 void generate_data ( 40 std::vector<sample_type>& samples, 41 std::vector<label_type>& labels 42 ) 43 { 44 const long num = 50; 45 46 sample_type m; 47 48 dlib::rand rnd; 49 50 51 // make some samples near the origin 52 double radius = 0.5; 53 for (long i = 0; i < num+10; ++i) 54 { 55 double sign = 1; 56 if (rnd.get_random_double() < 0.5) 57 sign = -1; 58 m(0) = 2*radius*rnd.get_random_double()-radius; 59 m(1) = sign*sqrt(radius*radius - m(0)*m(0)); 60 61 // add this sample to our set of samples we will run k-means 62 samples.push_back(m); 63 labels.push_back(1); 64 } 65 66 // make some samples in a circle around the origin but far away 67 radius = 10.0; 68 for (long i = 0; i < num+20; ++i) 69 { 70 double sign = 1; 71 if (rnd.get_random_double() < 0.5) 72 sign = -1; 73 m(0) = 2*radius*rnd.get_random_double()-radius; 74 m(1) = sign*sqrt(radius*radius - m(0)*m(0)); 75 76 // add this sample to our set of samples we will run k-means 77 samples.push_back(m); 78 labels.push_back(2); 79 } 80 81 // make some samples in a circle around the point (25,25) 82 radius = 4.0; 83 for (long i = 0; i < num+30; ++i) 84 { 85 double sign = 1; 86 if (rnd.get_random_double() < 0.5) 87 sign = -1; 88 m(0) = 2*radius*rnd.get_random_double()-radius; 89 m(1) = sign*sqrt(radius*radius - m(0)*m(0)); 90 91 // translate this point away from the origin 92 m(0) += 25; 93 m(1) += 25; 94 95 // add this sample to our set of samples we will run k-means 96 samples.push_back(m); 97 labels.push_back(3); 98 } 99 } 100 101 template <typename label_type, typename scalar_type> run_test()102 void run_test ( 103 ) 104 { 105 print_spinner(); 106 typedef matrix<scalar_type,2,1> sample_type; 107 108 std::vector<sample_type> samples, norm_samples; 109 std::vector<label_type> labels; 110 111 // First, get our labeled set of training data 112 generate_data(samples, labels); 113 114 typedef one_vs_one_trainer<any_trainer<sample_type,scalar_type>,label_type > ovo_trainer; 115 116 117 ovo_trainer trainer; 118 119 typedef histogram_intersection_kernel<sample_type> hist_kernel; 120 typedef radial_basis_kernel<sample_type> rbf_kernel; 121 122 // make the binary trainers and set some parameters 123 krr_trainer<rbf_kernel> rbf_trainer; 124 svm_nu_trainer<hist_kernel> hist_trainer; 125 rbf_trainer.set_kernel(rbf_kernel(0.1)); 126 127 128 trainer.set_trainer(rbf_trainer); 129 trainer.set_trainer(hist_trainer, 1, 2); 130 131 randomize_samples(samples, labels); 132 matrix<double> res = cross_validate_multiclass_trainer(trainer, samples, labels, 2); 133 134 print_spinner(); 135 136 matrix<scalar_type> ans(3,3); 137 ans = 60, 0, 0, 138 0, 70, 0, 139 0, 0, 80; 140 141 DLIB_TEST_MSG(ans == res, "res: \n" << res); 142 143 // test using a normalized_function with a one_vs_one_decision_function 144 { 145 trainer.set_trainer(hist_trainer, 1, 2); 146 vector_normalizer<sample_type> normalizer; 147 normalizer.train(samples); 148 for (unsigned long i = 0; i < samples.size(); ++i) 149 norm_samples.push_back(normalizer(samples[i])); 150 normalized_function<one_vs_one_decision_function<ovo_trainer> > ndf; 151 ndf.function = trainer.train(norm_samples, labels); 152 ndf.normalizer = normalizer; 153 DLIB_TEST(ndf(samples[0]) == labels[0]); 154 DLIB_TEST(ndf(samples[40]) == labels[40]); 155 DLIB_TEST(ndf(samples[90]) == labels[90]); 156 DLIB_TEST(ndf(samples[120]) == labels[120]); 157 trainer.set_trainer(hist_trainer, 1, 2); 158 print_spinner(); 159 } 160 161 162 163 164 one_vs_one_decision_function<ovo_trainer> df = trainer.train(samples, labels); 165 166 DLIB_TEST(df.number_of_classes() == 3); 167 168 DLIB_TEST(df(samples[0]) == labels[0]); 169 DLIB_TEST(df(samples[90]) == labels[90]); 170 171 172 one_vs_one_decision_function<ovo_trainer, 173 decision_function<hist_kernel>, // This is the output of the hist_trainer 174 decision_function<rbf_kernel> // This is the output of the rbf_trainer 175 > df2, df3; 176 177 178 df2 = df; 179 ofstream fout("df.dat", ios::binary); 180 serialize(df2, fout); 181 fout.close(); 182 183 // load the function back in from disk and store it in df3. 184 ifstream fin("df.dat", ios::binary); 185 deserialize(df3, fin); 186 187 188 DLIB_TEST(df3(samples[0]) == labels[0]); 189 DLIB_TEST(df3(samples[90]) == labels[90]); 190 res = test_multiclass_decision_function(df3, samples, labels); 191 192 DLIB_TEST(res == ans); 193 194 195 } 196 perform_test()197 void perform_test ( 198 ) 199 { 200 dlog << LINFO << "run_test<double,double>()"; 201 run_test<double,double>(); 202 203 dlog << LINFO << "run_test<int,double>()"; 204 run_test<int,double>(); 205 206 dlog << LINFO << "run_test<double,float>()"; 207 run_test<double,float>(); 208 209 dlog << LINFO << "run_test<int,float>()"; 210 run_test<int,float>(); 211 } 212 }; 213 214 test_one_vs_one_trainer a; 215 216 } 217 218 219