1 // Copyright (C) 2011  Davis E. King (davis@dlib.net)
2 // License: Boost Software License   See LICENSE.txt for the full license.
3 
4 #include <sstream>
5 #include <string>
6 #include <cstdlib>
7 #include <ctime>
8 #include <dlib/svm.h>
9 #include <dlib/matrix.h>
10 
11 #include "tester.h"
12 
13 namespace
14 {
15     using namespace test;
16     using namespace dlib;
17     using namespace std;
18 
    // Logger used by this test file; its messages are emitted under the
    // name "test.kmeans".
    logger dlog("test.kmeans");

    // File-scope random number generator.  run_test() draws the per-sample
    // noise from it, so successive run_test() calls share one random stream.
    dlib::rand rnd;
22 
23     template <typename sample_type>
run_test(const std::vector<sample_type> & seed_centers)24     void run_test(
25         const std::vector<sample_type>& seed_centers
26     )
27     {
28         print_spinner();
29 
30 
31         sample_type samp;
32 
33         std::vector<sample_type> samples;
34 
35 
36         for (unsigned long j = 0; j < seed_centers.size(); ++j)
37         {
38             for (int i = 0; i < 250; ++i)
39             {
40                 samp = randm(seed_centers[0].size(),1,rnd) - 0.5;
41                 samples.push_back(samp + seed_centers[j]);
42             }
43         }
44 
45         randomize_samples(samples);
46 
47         {
48             std::vector<sample_type> centers;
49             pick_initial_centers(seed_centers.size(), centers, samples, linear_kernel<sample_type>());
50 
51             find_clusters_using_kmeans(samples, centers);
52 
53             DLIB_TEST(centers.size() == seed_centers.size());
54 
55             std::vector<int> hits(centers.size(),0);
56             for (unsigned long i = 0; i < samples.size(); ++i)
57             {
58                 unsigned long best_idx = 0;
59                 double best_dist = 1e100;
60                 for (unsigned long j = 0; j < centers.size(); ++j)
61                 {
62                     if (length(samples[i] - centers[j]) < best_dist)
63                     {
64                         best_dist = length(samples[i] - centers[j]);
65                         best_idx = j;
66                     }
67                 }
68                 hits[best_idx]++;
69             }
70 
71             for (unsigned long i = 0; i < hits.size(); ++i)
72             {
73                 DLIB_TEST(hits[i] == 250);
74             }
75         }
76         {
77             std::vector<sample_type> centers;
78             pick_initial_centers(seed_centers.size(), centers, samples, linear_kernel<sample_type>());
79 
80             find_clusters_using_angular_kmeans(samples, centers);
81 
82             DLIB_TEST(centers.size() == seed_centers.size());
83 
84             std::vector<int> hits(centers.size(),0);
85             for (unsigned long i = 0; i < samples.size(); ++i)
86             {
87                 unsigned long best_idx = 0;
88                 double best_dist = 1e100;
89                 for (unsigned long j = 0; j < centers.size(); ++j)
90                 {
91                     if (length(samples[i] - centers[j]) < best_dist)
92                     {
93                         best_dist = length(samples[i] - centers[j]);
94                         best_idx = j;
95                     }
96                 }
97                 hits[best_idx]++;
98             }
99 
100             for (unsigned long i = 0; i < hits.size(); ++i)
101             {
102                 DLIB_TEST(hits[i] == 250);
103             }
104         }
105     }
106 
107 
108     class test_kmeans : public tester
109     {
110     public:
test_kmeans()111         test_kmeans (
112         ) :
113             tester ("test_kmeans",
114                     "Runs tests on the find_clusters_using_kmeans() function.")
115         {}
116 
perform_test()117         void perform_test (
118         )
119         {
120             {
121                 dlog << LINFO << "test dlib::vector<double,2>";
122                 typedef dlib::vector<double,2> sample_type;
123                 std::vector<sample_type> seed_centers;
124                 seed_centers.push_back(sample_type(10,10));
125                 seed_centers.push_back(sample_type(10,-10));
126                 seed_centers.push_back(sample_type(-10,10));
127                 seed_centers.push_back(sample_type(-10,-10));
128 
129                 run_test(seed_centers);
130             }
131             {
132                 dlog << LINFO << "test dlib::vector<double,2>";
133                 typedef dlib::vector<float,2> sample_type;
134                 std::vector<sample_type> seed_centers;
135                 seed_centers.push_back(sample_type(10,10));
136                 seed_centers.push_back(sample_type(10,-10));
137                 seed_centers.push_back(sample_type(-10,10));
138                 seed_centers.push_back(sample_type(-10,-10));
139 
140                 run_test(seed_centers);
141             }
142             {
143                 dlog << LINFO << "test dlib::matrix<double,3,1>";
144                 typedef dlib::matrix<double,3,1> sample_type;
145                 std::vector<sample_type> seed_centers;
146                 sample_type samp;
147                 samp = 10,10,0; seed_centers.push_back(samp);
148                 samp = -10,10,1; seed_centers.push_back(samp);
149                 samp = -10,-10,2; seed_centers.push_back(samp);
150 
151                 run_test(seed_centers);
152             }
153 
154 
155         }
156     } a;
157 
158 
159 
160 }
161 
162 
163 
164