// Copyright (C) 2010  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.

#include "tester.h"
#include <dlib/svm_threaded.h>
#include <dlib/statistics.h>
#include <vector>
#include <sstream>

namespace
{
    using namespace test;
    using namespace dlib;
    using namespace std;
    dlib::logger dlog("test.one_vs_one_trainer");


    class test_one_vs_one_trainer : public tester
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                This object represents a unit test.  When it is constructed
                it adds itself into the testing framework.
        !*/
    public:
        test_one_vs_one_trainer (
        ) :
            tester (
                "test_one_vs_one_trainer",       // the command line argument name for this test
                "Run tests on the one_vs_one_trainer stuff.", // the command line argument description
                0                     // the number of command line arguments for this test
            )
        {
        }


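        // Generate a small 2D training set containing three well separated classes:
        // 60 samples on a circle of radius 0.5 around the origin (label 1), 70 samples
        // on a circle of radius 10 around the origin (label 2), and 80 samples on a
        // circle of radius 4 centered at (25,25) (label 3).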
        template <typename sample_type, typename label_type>
        void generate_data (
            std::vector<sample_type>& samples,
            std::vector<label_type>& labels
        )
        {
            const long num = 50;

            sample_type m;

            dlib::rand rnd;


            // make some samples near the origin
            double radius = 0.5;
            for (long i = 0; i < num+10; ++i)
            {
                double sign = 1;
                if (rnd.get_random_double() < 0.5)
                    sign = -1;
                m(0) = 2*radius*rnd.get_random_double()-radius;
                m(1) = sign*sqrt(radius*radius - m(0)*m(0));

                // add this sample to our set of training samples
                samples.push_back(m);
                labels.push_back(1);
            }

            // make some samples in a circle around the origin but far away
            radius = 10.0;
            for (long i = 0; i < num+20; ++i)
            {
                double sign = 1;
                if (rnd.get_random_double() < 0.5)
                    sign = -1;
                m(0) = 2*radius*rnd.get_random_double()-radius;
                m(1) = sign*sqrt(radius*radius - m(0)*m(0));

                // add this sample to our set of training samples
                samples.push_back(m);
                labels.push_back(2);
            }

            // make some samples in a circle around the point (25,25)
            radius = 4.0;
            for (long i = 0; i < num+30; ++i)
            {
                double sign = 1;
                if (rnd.get_random_double() < 0.5)
                    sign = -1;
                m(0) = 2*radius*rnd.get_random_double()-radius;
                m(1) = sign*sqrt(radius*radius - m(0)*m(0));

                // translate this point away from the origin
                m(0) += 25;
                m(1) += 25;

                // add this sample to our set of training samples
                samples.push_back(m);
                labels.push_back(3);
            }
        }

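        // Train a one_vs_one_trainer on the data from generate_data() and verify that
        // cross-validation, normalization, direct evaluation, and serialization of the
        // resulting decision function all give perfect results on this easy problem.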
        template <typename label_type, typename scalar_type>
        void run_test (
        )
        {
            print_spinner();
            typedef matrix<scalar_type,2,1> sample_type;

            std::vector<sample_type> samples, norm_samples;
            std::vector<label_type> labels;

            // First, get our labeled set of training data
            generate_data(samples, labels);

            typedef one_vs_one_trainer<any_trainer<sample_type,scalar_type>,label_type > ovo_trainer;


            ovo_trainer trainer;

            typedef histogram_intersection_kernel<sample_type> hist_kernel;
            typedef radial_basis_kernel<sample_type> rbf_kernel;

            // make the binary trainers and set some parameters
            krr_trainer<rbf_kernel> rbf_trainer;
            svm_nu_trainer<hist_kernel> hist_trainer;
            rbf_trainer.set_kernel(rbf_kernel(0.1));

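            // rbf_trainer becomes the default trainer used for every pair of labels,
            // while hist_trainer is used only for the label 1 vs. label 2 subproblem.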
            trainer.set_trainer(rbf_trainer);
            trainer.set_trainer(hist_trainer, 1, 2);

            randomize_samples(samples, labels);
            matrix<double> res = cross_validate_multiclass_trainer(trainer, samples, labels, 2);

            print_spinner();

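            // The classes are easily separable, so 2-fold cross-validation should produce
            // a perfect confusion matrix: all 60, 70, and 80 samples of labels 1, 2, and 3
            // land on the diagonal.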
            matrix<scalar_type> ans(3,3);
            ans = 60,  0,  0,
                   0, 70,  0,
                   0,  0, 80;

            DLIB_TEST_MSG(ans == res, "res: \n" << res);

            // test using a normalized_function with a one_vs_one_decision_function
            {
                trainer.set_trainer(hist_trainer, 1, 2);
                vector_normalizer<sample_type> normalizer;
                normalizer.train(samples);
                for (unsigned long i = 0; i < samples.size(); ++i)
                    norm_samples.push_back(normalizer(samples[i]));
                normalized_function<one_vs_one_decision_function<ovo_trainer> > ndf;
                ndf.function = trainer.train(norm_samples, labels);
                ndf.normalizer = normalizer;
                DLIB_TEST(ndf(samples[0]) == labels[0]);
                DLIB_TEST(ndf(samples[40]) == labels[40]);
                DLIB_TEST(ndf(samples[90]) == labels[90]);
                DLIB_TEST(ndf(samples[120]) == labels[120]);
                trainer.set_trainer(hist_trainer, 1, 2);
                print_spinner();
            }


            one_vs_one_decision_function<ovo_trainer> df = trainer.train(samples, labels);

            DLIB_TEST(df.number_of_classes() == 3);

            DLIB_TEST(df(samples[0]) == labels[0]);
            DLIB_TEST(df(samples[90]) == labels[90]);

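            // df stores its binary classifiers as type-erased any_decision_function
            // objects, so copy it into a one_vs_one_decision_function that lists the
            // concrete decision_function types in order to serialize it.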
            one_vs_one_decision_function<ovo_trainer,
                decision_function<hist_kernel>,  // This is the output of the hist_trainer
                decision_function<rbf_kernel>    // This is the output of the rbf_trainer
            > df2, df3;

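            // convert df into the concrete type and save it to disk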
            df2 = df;
            ofstream fout("df.dat", ios::binary);
            serialize(df2, fout);
            fout.close();

            // load the function back in from disk and store it in df3.
            ifstream fin("df.dat", ios::binary);
            deserialize(df3, fin);


            DLIB_TEST(df3(samples[0]) == labels[0]);
            DLIB_TEST(df3(samples[90]) == labels[90]);
            res = test_multiclass_decision_function(df3, samples, labels);

            DLIB_TEST(res == ans);

        }

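        // Run the test over several combinations of label and scalar types.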
        void perform_test (
        )
        {
            dlog << LINFO << "run_test<double,double>()";
            run_test<double,double>();

            dlog << LINFO << "run_test<int,double>()";
            run_test<int,double>();

            dlog << LINFO << "run_test<double,float>()";
            run_test<double,float>();

            dlog << LINFO << "run_test<int,float>()";
            run_test<int,float>();
        }
    };

    test_one_vs_one_trainer a;

}