1 /**
2  * @file tests/emst_test.cpp
3  *
4  * Test file for EMST methods.
5  *
6  * mlpack is free software; you may redistribute it and/or modify it under the
7  * terms of the 3-clause BSD license.  You should have received a copy of the
8  * 3-clause BSD license along with mlpack.  If not, see
9  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
10  */
11 #include <mlpack/core.hpp>
12 #include <mlpack/methods/emst/dtb.hpp>
13 #include <boost/test/unit_test.hpp>
14 #include "test_tools.hpp"
15 
16 #include <mlpack/core/tree/cover_tree.hpp>
17 
18 using namespace mlpack;
19 using namespace mlpack::emst;
20 using namespace mlpack::tree;
21 using namespace mlpack::bound;
22 using namespace mlpack::metric;
23 
24 BOOST_AUTO_TEST_SUITE(EMSTTest);
25 
26 /**
27  * Simple emst test with small, synthetic dataset.  This is an
28  * exhaustive test, which checks that each method for performing the calculation
29  * (dual-tree, naive) produces the correct results.  The dataset is in one
30  * dimension for simplicity -- the correct functionality of distance functions
31  * is not tested here.
32  */
BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)33 BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
34 {
35   // Set up our data.
36   arma::mat data(1, 11);
37   data[0] = 0.05; // Row addressing is unnecessary (they are all 0).
38   data[1] = 0.37;
39   data[2] = 0.15;
40   data[3] = 1.25;
41   data[4] = 5.05;
42   data[5] = -0.22;
43   data[6] = -2.00;
44   data[7] = -1.30;
45   data[8] = 0.45;
46   data[9] = 0.91;
47   data[10] = 1.00;
48 
49   arma::mat results;
50 
51   // Build the tree by hand to get a leaf size of 1.
52   typedef KDTree<EuclideanDistance, DTBStat, arma::mat> TreeType;
53   std::vector<size_t> oldFromNew;
54   std::vector<size_t> newFromOld;
55   TreeType tree(data, oldFromNew, newFromOld, 1);
56 
57   // Create the DTB object and run the calculation.
58   DualTreeBoruvka<> dtb(&tree);
59   dtb.ComputeMST(results);
60 
61   // Now the exhaustive check for correctness.
62   if (newFromOld[1] < newFromOld[8])
63   {
64     BOOST_REQUIRE_EQUAL(results(0, 0), newFromOld[1]);
65     BOOST_REQUIRE_EQUAL(results(1, 0), newFromOld[8]);
66   }
67   else
68   {
69     BOOST_REQUIRE_EQUAL(results(1, 0), newFromOld[1]);
70     BOOST_REQUIRE_EQUAL(results(0, 0), newFromOld[8]);
71   }
72   BOOST_REQUIRE_CLOSE(results(2, 0), 0.08, 1e-5);
73 
74   if (newFromOld[9] < newFromOld[10])
75   {
76     BOOST_REQUIRE_EQUAL(results(0, 1), newFromOld[9]);
77     BOOST_REQUIRE_EQUAL(results(1, 1), newFromOld[10]);
78   }
79   else
80   {
81     BOOST_REQUIRE_EQUAL(results(1, 1), newFromOld[9]);
82     BOOST_REQUIRE_EQUAL(results(0, 1), newFromOld[10]);
83   }
84   BOOST_REQUIRE_CLOSE(results(2, 1), 0.09, 1e-5);
85 
86   if (newFromOld[0] < newFromOld[2])
87   {
88     BOOST_REQUIRE_EQUAL(results(0, 2), newFromOld[0]);
89     BOOST_REQUIRE_EQUAL(results(1, 2), newFromOld[2]);
90   }
91   else
92   {
93     BOOST_REQUIRE_EQUAL(results(1, 2), newFromOld[0]);
94     BOOST_REQUIRE_EQUAL(results(0, 2), newFromOld[2]);
95   }
96   BOOST_REQUIRE_CLOSE(results(2, 2), 0.1, 1e-5);
97 
98   if (newFromOld[1] < newFromOld[2])
99   {
100     BOOST_REQUIRE_EQUAL(results(0, 3), newFromOld[1]);
101     BOOST_REQUIRE_EQUAL(results(1, 3), newFromOld[2]);
102   }
103   else
104   {
105     BOOST_REQUIRE_EQUAL(results(1, 3), newFromOld[1]);
106     BOOST_REQUIRE_EQUAL(results(0, 3), newFromOld[2]);
107   }
108   BOOST_REQUIRE_CLOSE(results(2, 3), 0.22, 1e-5);
109 
110   if (newFromOld[3] < newFromOld[10])
111   {
112     BOOST_REQUIRE_EQUAL(results(0, 4), newFromOld[3]);
113     BOOST_REQUIRE_EQUAL(results(1, 4), newFromOld[10]);
114   }
115   else
116   {
117     BOOST_REQUIRE_EQUAL(results(1, 4), newFromOld[3]);
118     BOOST_REQUIRE_EQUAL(results(0, 4), newFromOld[10]);
119   }
120   BOOST_REQUIRE_CLOSE(results(2, 4), 0.25, 1e-5);
121 
122   if (newFromOld[0] < newFromOld[5])
123   {
124     BOOST_REQUIRE_EQUAL(results(0, 5), newFromOld[0]);
125     BOOST_REQUIRE_EQUAL(results(1, 5), newFromOld[5]);
126   }
127   else
128   {
129     BOOST_REQUIRE_EQUAL(results(1, 5), newFromOld[0]);
130     BOOST_REQUIRE_EQUAL(results(0, 5), newFromOld[5]);
131   }
132   BOOST_REQUIRE_CLOSE(results(2, 5), 0.27, 1e-5);
133 
134   if (newFromOld[8] < newFromOld[9])
135   {
136     BOOST_REQUIRE_EQUAL(results(0, 6), newFromOld[8]);
137     BOOST_REQUIRE_EQUAL(results(1, 6), newFromOld[9]);
138   }
139   else
140   {
141     BOOST_REQUIRE_EQUAL(results(1, 6), newFromOld[8]);
142     BOOST_REQUIRE_EQUAL(results(0, 6), newFromOld[9]);
143   }
144   BOOST_REQUIRE_CLOSE(results(2, 6), 0.46, 1e-5);
145 
146   if (newFromOld[6] < newFromOld[7])
147   {
148     BOOST_REQUIRE_EQUAL(results(0, 7), newFromOld[6]);
149     BOOST_REQUIRE_EQUAL(results(1, 7), newFromOld[7]);
150   }
151   else
152   {
153     BOOST_REQUIRE_EQUAL(results(1, 7), newFromOld[6]);
154     BOOST_REQUIRE_EQUAL(results(0, 7), newFromOld[7]);
155   }
156   BOOST_REQUIRE_CLOSE(results(2, 7), 0.7, 1e-5);
157 
158   if (newFromOld[5] < newFromOld[7])
159   {
160     BOOST_REQUIRE_EQUAL(results(0, 8), newFromOld[5]);
161     BOOST_REQUIRE_EQUAL(results(1, 8), newFromOld[7]);
162   }
163   else
164   {
165     BOOST_REQUIRE_EQUAL(results(1, 8), newFromOld[5]);
166     BOOST_REQUIRE_EQUAL(results(0, 8), newFromOld[7]);
167   }
168   BOOST_REQUIRE_CLOSE(results(2, 8), 1.08, 1e-5);
169 
170   if (newFromOld[3] < newFromOld[4])
171   {
172     BOOST_REQUIRE_EQUAL(results(0, 9), newFromOld[3]);
173     BOOST_REQUIRE_EQUAL(results(1, 9), newFromOld[4]);
174   }
175   else
176   {
177     BOOST_REQUIRE_EQUAL(results(1, 9), newFromOld[3]);
178     BOOST_REQUIRE_EQUAL(results(0, 9), newFromOld[4]);
179   }
180   BOOST_REQUIRE_CLOSE(results(2, 9), 3.8, 1e-5);
181 }
182 
183 /**
184  * Test the dual tree method against the naive computation.
185  *
186  * Errors are produced if the results are not identical.
187  */
BOOST_AUTO_TEST_CASE(DualTreeVsNaive)188 BOOST_AUTO_TEST_CASE(DualTreeVsNaive)
189 {
190   arma::mat inputData;
191 
192   // Hard-coded filename: bad!
193   // Code duplication: also bad!
194   if (!data::Load("test_data_3_1000.csv", inputData))
195     BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
196 
197   // Set up matrices to work with.
198   arma::mat dualData = inputData;
199   arma::mat naiveData = inputData;
200 
201   // Reset parameters from last test.
202   DualTreeBoruvka<> dtb(dualData);
203 
204   arma::mat dualResults;
205   dtb.ComputeMST(dualResults);
206 
207   // Set naive mode.
208   DualTreeBoruvka<> dtbNaive(naiveData, true);
209 
210   arma::mat naiveResults;
211   dtbNaive.ComputeMST(naiveResults);
212 
213   BOOST_REQUIRE_EQUAL(dualResults.n_cols, naiveResults.n_cols);
214   BOOST_REQUIRE_EQUAL(dualResults.n_rows, naiveResults.n_rows);
215 
216   for (size_t i = 0; i < dualResults.n_cols; ++i)
217   {
218     BOOST_REQUIRE_EQUAL(dualResults(0, i), naiveResults(0, i));
219     BOOST_REQUIRE_EQUAL(dualResults(1, i), naiveResults(1, i));
220     BOOST_REQUIRE_CLOSE(dualResults(2, i), naiveResults(2, i), 1e-5);
221   }
222 }
223 
224 /**
225  * Make sure the cover tree works fine.
226  */
BOOST_AUTO_TEST_CASE(CoverTreeTest)227 BOOST_AUTO_TEST_CASE(CoverTreeTest)
228 {
229   arma::mat inputData;
230   if (!data::Load("test_data_3_1000.csv", inputData))
231     BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
232 
233   DualTreeBoruvka<> bst(inputData);
234   DualTreeBoruvka<EuclideanDistance, arma::mat, StandardCoverTree>
235       ct(inputData);
236 
237   arma::mat bstResults;
238   arma::mat coverResults;
239 
240   // Run the algorithms.
241   bst.ComputeMST(bstResults);
242   ct.ComputeMST(coverResults);
243 
244   for (size_t i = 0; i < bstResults.n_cols; ++i)
245   {
246     BOOST_REQUIRE_EQUAL(bstResults(0, i), coverResults(0, i));
247     BOOST_REQUIRE_EQUAL(bstResults(1, i), coverResults(1, i));
248     BOOST_REQUIRE_CLOSE(bstResults(2, i), coverResults(2, i), 1e-5);
249   }
250 }
251 
252 /**
253  * Test BinarySpaceTree with Ball Bound.
254  */
BOOST_AUTO_TEST_CASE(BallTreeTest)255 BOOST_AUTO_TEST_CASE(BallTreeTest)
256 {
257   arma::mat inputData;
258   if (!data::Load("test_data_3_1000.csv", inputData))
259     BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
260 
261   // naive mode.
262   DualTreeBoruvka<> bst(inputData, true);
263   // Ball tree.
264   DualTreeBoruvka<EuclideanDistance, arma::mat, BallTree> ballt(inputData);
265 
266   arma::mat bstResults;
267   arma::mat ballResults;
268 
269   // Run the algorithms.
270   bst.ComputeMST(bstResults);
271   ballt.ComputeMST(ballResults);
272 
273   for (size_t i = 0; i < bstResults.n_cols; ++i)
274   {
275     BOOST_REQUIRE_EQUAL(bstResults(0, i), ballResults(0, i));
276     BOOST_REQUIRE_EQUAL(bstResults(1, i), ballResults(1, i));
277     BOOST_REQUIRE_CLOSE(bstResults(2, i), ballResults(2, i), 1e-5);
278   }
279 }
280 
281 BOOST_AUTO_TEST_SUITE_END();
282