1 /**
2 * @file tests/emst_test.cpp
3 *
4 * Test file for EMST methods.
5 *
6 * mlpack is free software; you may redistribute it and/or modify it under the
7 * terms of the 3-clause BSD license. You should have received a copy of the
8 * 3-clause BSD license along with mlpack. If not, see
9 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
10 */
11 #include <mlpack/core.hpp>
12 #include <mlpack/methods/emst/dtb.hpp>
13 #include <boost/test/unit_test.hpp>
14 #include "test_tools.hpp"
15
16 #include <mlpack/core/tree/cover_tree.hpp>
17
18 using namespace mlpack;
19 using namespace mlpack::emst;
20 using namespace mlpack::tree;
21 using namespace mlpack::bound;
22 using namespace mlpack::metric;
23
24 BOOST_AUTO_TEST_SUITE(EMSTTest);
25
26 /**
27 * Simple emst test with small, synthetic dataset. This is an
28 * exhaustive test, which checks that each method for performing the calculation
29 * (dual-tree, naive) produces the correct results. The dataset is in one
30 * dimension for simplicity -- the correct functionality of distance functions
31 * is not tested here.
32 */
BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)33 BOOST_AUTO_TEST_CASE(ExhaustiveSyntheticTest)
34 {
35 // Set up our data.
36 arma::mat data(1, 11);
37 data[0] = 0.05; // Row addressing is unnecessary (they are all 0).
38 data[1] = 0.37;
39 data[2] = 0.15;
40 data[3] = 1.25;
41 data[4] = 5.05;
42 data[5] = -0.22;
43 data[6] = -2.00;
44 data[7] = -1.30;
45 data[8] = 0.45;
46 data[9] = 0.91;
47 data[10] = 1.00;
48
49 arma::mat results;
50
51 // Build the tree by hand to get a leaf size of 1.
52 typedef KDTree<EuclideanDistance, DTBStat, arma::mat> TreeType;
53 std::vector<size_t> oldFromNew;
54 std::vector<size_t> newFromOld;
55 TreeType tree(data, oldFromNew, newFromOld, 1);
56
57 // Create the DTB object and run the calculation.
58 DualTreeBoruvka<> dtb(&tree);
59 dtb.ComputeMST(results);
60
61 // Now the exhaustive check for correctness.
62 if (newFromOld[1] < newFromOld[8])
63 {
64 BOOST_REQUIRE_EQUAL(results(0, 0), newFromOld[1]);
65 BOOST_REQUIRE_EQUAL(results(1, 0), newFromOld[8]);
66 }
67 else
68 {
69 BOOST_REQUIRE_EQUAL(results(1, 0), newFromOld[1]);
70 BOOST_REQUIRE_EQUAL(results(0, 0), newFromOld[8]);
71 }
72 BOOST_REQUIRE_CLOSE(results(2, 0), 0.08, 1e-5);
73
74 if (newFromOld[9] < newFromOld[10])
75 {
76 BOOST_REQUIRE_EQUAL(results(0, 1), newFromOld[9]);
77 BOOST_REQUIRE_EQUAL(results(1, 1), newFromOld[10]);
78 }
79 else
80 {
81 BOOST_REQUIRE_EQUAL(results(1, 1), newFromOld[9]);
82 BOOST_REQUIRE_EQUAL(results(0, 1), newFromOld[10]);
83 }
84 BOOST_REQUIRE_CLOSE(results(2, 1), 0.09, 1e-5);
85
86 if (newFromOld[0] < newFromOld[2])
87 {
88 BOOST_REQUIRE_EQUAL(results(0, 2), newFromOld[0]);
89 BOOST_REQUIRE_EQUAL(results(1, 2), newFromOld[2]);
90 }
91 else
92 {
93 BOOST_REQUIRE_EQUAL(results(1, 2), newFromOld[0]);
94 BOOST_REQUIRE_EQUAL(results(0, 2), newFromOld[2]);
95 }
96 BOOST_REQUIRE_CLOSE(results(2, 2), 0.1, 1e-5);
97
98 if (newFromOld[1] < newFromOld[2])
99 {
100 BOOST_REQUIRE_EQUAL(results(0, 3), newFromOld[1]);
101 BOOST_REQUIRE_EQUAL(results(1, 3), newFromOld[2]);
102 }
103 else
104 {
105 BOOST_REQUIRE_EQUAL(results(1, 3), newFromOld[1]);
106 BOOST_REQUIRE_EQUAL(results(0, 3), newFromOld[2]);
107 }
108 BOOST_REQUIRE_CLOSE(results(2, 3), 0.22, 1e-5);
109
110 if (newFromOld[3] < newFromOld[10])
111 {
112 BOOST_REQUIRE_EQUAL(results(0, 4), newFromOld[3]);
113 BOOST_REQUIRE_EQUAL(results(1, 4), newFromOld[10]);
114 }
115 else
116 {
117 BOOST_REQUIRE_EQUAL(results(1, 4), newFromOld[3]);
118 BOOST_REQUIRE_EQUAL(results(0, 4), newFromOld[10]);
119 }
120 BOOST_REQUIRE_CLOSE(results(2, 4), 0.25, 1e-5);
121
122 if (newFromOld[0] < newFromOld[5])
123 {
124 BOOST_REQUIRE_EQUAL(results(0, 5), newFromOld[0]);
125 BOOST_REQUIRE_EQUAL(results(1, 5), newFromOld[5]);
126 }
127 else
128 {
129 BOOST_REQUIRE_EQUAL(results(1, 5), newFromOld[0]);
130 BOOST_REQUIRE_EQUAL(results(0, 5), newFromOld[5]);
131 }
132 BOOST_REQUIRE_CLOSE(results(2, 5), 0.27, 1e-5);
133
134 if (newFromOld[8] < newFromOld[9])
135 {
136 BOOST_REQUIRE_EQUAL(results(0, 6), newFromOld[8]);
137 BOOST_REQUIRE_EQUAL(results(1, 6), newFromOld[9]);
138 }
139 else
140 {
141 BOOST_REQUIRE_EQUAL(results(1, 6), newFromOld[8]);
142 BOOST_REQUIRE_EQUAL(results(0, 6), newFromOld[9]);
143 }
144 BOOST_REQUIRE_CLOSE(results(2, 6), 0.46, 1e-5);
145
146 if (newFromOld[6] < newFromOld[7])
147 {
148 BOOST_REQUIRE_EQUAL(results(0, 7), newFromOld[6]);
149 BOOST_REQUIRE_EQUAL(results(1, 7), newFromOld[7]);
150 }
151 else
152 {
153 BOOST_REQUIRE_EQUAL(results(1, 7), newFromOld[6]);
154 BOOST_REQUIRE_EQUAL(results(0, 7), newFromOld[7]);
155 }
156 BOOST_REQUIRE_CLOSE(results(2, 7), 0.7, 1e-5);
157
158 if (newFromOld[5] < newFromOld[7])
159 {
160 BOOST_REQUIRE_EQUAL(results(0, 8), newFromOld[5]);
161 BOOST_REQUIRE_EQUAL(results(1, 8), newFromOld[7]);
162 }
163 else
164 {
165 BOOST_REQUIRE_EQUAL(results(1, 8), newFromOld[5]);
166 BOOST_REQUIRE_EQUAL(results(0, 8), newFromOld[7]);
167 }
168 BOOST_REQUIRE_CLOSE(results(2, 8), 1.08, 1e-5);
169
170 if (newFromOld[3] < newFromOld[4])
171 {
172 BOOST_REQUIRE_EQUAL(results(0, 9), newFromOld[3]);
173 BOOST_REQUIRE_EQUAL(results(1, 9), newFromOld[4]);
174 }
175 else
176 {
177 BOOST_REQUIRE_EQUAL(results(1, 9), newFromOld[3]);
178 BOOST_REQUIRE_EQUAL(results(0, 9), newFromOld[4]);
179 }
180 BOOST_REQUIRE_CLOSE(results(2, 9), 3.8, 1e-5);
181 }
182
183 /**
184 * Test the dual tree method against the naive computation.
185 *
186 * Errors are produced if the results are not identical.
187 */
BOOST_AUTO_TEST_CASE(DualTreeVsNaive)188 BOOST_AUTO_TEST_CASE(DualTreeVsNaive)
189 {
190 arma::mat inputData;
191
192 // Hard-coded filename: bad!
193 // Code duplication: also bad!
194 if (!data::Load("test_data_3_1000.csv", inputData))
195 BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
196
197 // Set up matrices to work with.
198 arma::mat dualData = inputData;
199 arma::mat naiveData = inputData;
200
201 // Reset parameters from last test.
202 DualTreeBoruvka<> dtb(dualData);
203
204 arma::mat dualResults;
205 dtb.ComputeMST(dualResults);
206
207 // Set naive mode.
208 DualTreeBoruvka<> dtbNaive(naiveData, true);
209
210 arma::mat naiveResults;
211 dtbNaive.ComputeMST(naiveResults);
212
213 BOOST_REQUIRE_EQUAL(dualResults.n_cols, naiveResults.n_cols);
214 BOOST_REQUIRE_EQUAL(dualResults.n_rows, naiveResults.n_rows);
215
216 for (size_t i = 0; i < dualResults.n_cols; ++i)
217 {
218 BOOST_REQUIRE_EQUAL(dualResults(0, i), naiveResults(0, i));
219 BOOST_REQUIRE_EQUAL(dualResults(1, i), naiveResults(1, i));
220 BOOST_REQUIRE_CLOSE(dualResults(2, i), naiveResults(2, i), 1e-5);
221 }
222 }
223
224 /**
225 * Make sure the cover tree works fine.
226 */
BOOST_AUTO_TEST_CASE(CoverTreeTest)227 BOOST_AUTO_TEST_CASE(CoverTreeTest)
228 {
229 arma::mat inputData;
230 if (!data::Load("test_data_3_1000.csv", inputData))
231 BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
232
233 DualTreeBoruvka<> bst(inputData);
234 DualTreeBoruvka<EuclideanDistance, arma::mat, StandardCoverTree>
235 ct(inputData);
236
237 arma::mat bstResults;
238 arma::mat coverResults;
239
240 // Run the algorithms.
241 bst.ComputeMST(bstResults);
242 ct.ComputeMST(coverResults);
243
244 for (size_t i = 0; i < bstResults.n_cols; ++i)
245 {
246 BOOST_REQUIRE_EQUAL(bstResults(0, i), coverResults(0, i));
247 BOOST_REQUIRE_EQUAL(bstResults(1, i), coverResults(1, i));
248 BOOST_REQUIRE_CLOSE(bstResults(2, i), coverResults(2, i), 1e-5);
249 }
250 }
251
252 /**
253 * Test BinarySpaceTree with Ball Bound.
254 */
BOOST_AUTO_TEST_CASE(BallTreeTest)255 BOOST_AUTO_TEST_CASE(BallTreeTest)
256 {
257 arma::mat inputData;
258 if (!data::Load("test_data_3_1000.csv", inputData))
259 BOOST_FAIL("Cannot load test dataset test_data_3_1000.csv!");
260
261 // naive mode.
262 DualTreeBoruvka<> bst(inputData, true);
263 // Ball tree.
264 DualTreeBoruvka<EuclideanDistance, arma::mat, BallTree> ballt(inputData);
265
266 arma::mat bstResults;
267 arma::mat ballResults;
268
269 // Run the algorithms.
270 bst.ComputeMST(bstResults);
271 ballt.ComputeMST(ballResults);
272
273 for (size_t i = 0; i < bstResults.n_cols; ++i)
274 {
275 BOOST_REQUIRE_EQUAL(bstResults(0, i), ballResults(0, i));
276 BOOST_REQUIRE_EQUAL(bstResults(1, i), ballResults(1, i));
277 BOOST_REQUIRE_CLOSE(bstResults(2, i), ballResults(2, i), 1e-5);
278 }
279 }
280
281 BOOST_AUTO_TEST_SUITE_END();
282