1 /**
2 * @file methods/emst/emst_main.cpp
3 * @author Bill March (march@gatech.edu)
4 *
5 * Calls the DualTreeBoruvka algorithm from dtb.hpp.
6 * Can optionally call naive Boruvka's method.
7 *
8 * For algorithm details, see:
9 *
10 * @code
11 * @inproceedings{
12 * author = {March, W.B., Ram, P., and Gray, A.G.},
 *   title = {{Fast Euclidean Minimum Spanning Tree: Algorithm, Analysis,
 *      and Applications.}},
 *   booktitle = {Proceedings of the 16th ACM SIGKDD International Conference
 *      on Knowledge Discovery and Data Mining},
17 * series = {KDD 2010},
18 * year = {2010}
19 * }
20 * @endcode
21 *
22 * mlpack is free software; you may redistribute it and/or modify it under the
23 * terms of the 3-clause BSD license. You should have received a copy of the
24 * 3-clause BSD license along with mlpack. If not, see
25 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
26 */
27 #include <mlpack/prereqs.hpp>
28 #include <mlpack/core/util/io.hpp>
29 #include <mlpack/core/util/mlpack_main.hpp>
30
31 #include "dtb.hpp"
32
// Program name: the human-readable title of this binding.
BINDING_NAME("Fast Euclidean Minimum Spanning Tree");

// Short description: a one-sentence summary of what the binding computes.
BINDING_SHORT_DESC(
    "An implementation of the Dual-Tree Boruvka algorithm for computing the "
    "Euclidean minimum spanning tree of a set of input points.");

// Long description: explains how the "input" and "output" parameters are
// used, and how the "leaf_size" and "naive" options affect the computation.
BINDING_LONG_DESC(
    "This program can compute the Euclidean minimum spanning tree of a set of "
    "input points using the dual-tree Boruvka algorithm."
    "\n\n"
    "The set to calculate the minimum spanning tree of is specified with the " +
    PRINT_PARAM_STRING("input") + " parameter, and the output may be saved with"
    " the " + PRINT_PARAM_STRING("output") + " output parameter."
    "\n\n"
    "The " + PRINT_PARAM_STRING("leaf_size") + " parameter controls the leaf "
    "size of the kd-tree that is used to calculate the minimum spanning tree, "
    "and if the " + PRINT_PARAM_STRING("naive") + " option is given, then "
    "brute-force search is used (this is typically much slower in low "
    "dimensions). The leaf size does not affect the results, but it may have "
    "some effect on the runtime of the algorithm.");
56
57 // Example.
58 BINDING_EXAMPLE(
59 "For example, the minimum spanning tree of the input dataset " +
60 PRINT_DATASET("data") + " can be calculated with a leaf size of 20 and "
61 "stored as " + PRINT_DATASET("spanning_tree") + " using the following "
62 "command:"
63 "\n\n" +
64 PRINT_CALL("emst", "input", "data", "leaf_size", 20, "output",
65 "spanning_tree") +
66 "\n\n"
67 "The output matrix is a three-dimensional matrix, where each row indicates "
68 "an edge. The first dimension corresponds to the lesser index of the edge;"
69 " the second dimension corresponds to the greater index of the edge; and "
70 "the third column corresponds to the distance between the two points.");
71
// See also: related documentation.  Each entry pairs a description with a
// link target.
// NOTE(review): the "@doxygen/" prefix is presumably resolved by the binding
// documentation generator -- confirm against the macro definition.
BINDING_SEE_ALSO("EMST Tutorial", "@doxygen/emst_tutorial.html");
BINDING_SEE_ALSO("Minimum spanning tree on Wikipedia",
    "https://en.wikipedia.org/wiki/Minimum_spanning_tree");
BINDING_SEE_ALSO("Fast Euclidean Minimum Spanning Tree: Algorithm, Analysis,"
    " and Applications (pdf)", "http://www.mlpack.org/papers/emst.pdf");
BINDING_SEE_ALSO("mlpack::emst::DualTreeBoruvka class documentation",
    "@doxygen/classmlpack_1_1emst_1_1DualTreeBoruvka.html");
80
// Binding parameters: the input matrix, the output edge list, the flag that
// selects the naive algorithm, and the kd-tree leaf size.
// NOTE(review): the single-letter third argument looks like a short option
// alias and the trailing 1 of PARAM_INT_IN looks like the default leaf size --
// confirm both against the macro definitions.
PARAM_MATRIX_IN_REQ("input", "Input data matrix.", "i");
PARAM_MATRIX_OUT("output", "Output data. Stored as an edge list.", "o");
PARAM_FLAG("naive", "Compute the MST using O(n^2) naive algorithm.", "n");
PARAM_INT_IN("leaf_size", "Leaf size in the kd-tree. One-element leaves give "
    "the empirically best performance, but at the cost of greater memory "
    "requirements.", "l", 1);
87
88 using namespace mlpack;
89 using namespace mlpack::emst;
90 using namespace mlpack::tree;
91 using namespace mlpack::metric;
92 using namespace mlpack::util;
93 using namespace std;
94
mlpackMain()95 static void mlpackMain()
96 {
97 RequireAtLeastOnePassed({ "output" }, false, "no output will be saved");
98
99 arma::mat dataPoints = std::move(IO::GetParam<arma::mat>("input"));
100
101 // Do naive computation if necessary.
102 if (IO::GetParam<bool>("naive"))
103 {
104 Log::Info << "Running naive algorithm." << endl;
105
106 DualTreeBoruvka<> naive(dataPoints, true);
107
108 arma::mat naiveResults;
109 naive.ComputeMST(naiveResults);
110
111 if (IO::HasParam("output"))
112 IO::GetParam<arma::mat>("output") = std::move(naiveResults);
113 }
114 else
115 {
116 Log::Info << "Building tree.\n";
117
118 // Check that the leaf size is reasonable.
119 RequireParamValue<int>("leaf_size", [](int x) { return x > 0; }, true,
120 "leaf size must be greater than or equal to 1");
121
122 // Initialize the tree and get ready to compute the MST. Compute the tree
123 // by hand.
124 const size_t leafSize = (size_t) IO::GetParam<int>("leaf_size");
125
126 Timer::Start("tree_building");
127 std::vector<size_t> oldFromNew;
128 KDTree<EuclideanDistance, DTBStat, arma::mat> tree(dataPoints, oldFromNew,
129 leafSize);
130 metric::LMetric<2, true> metric;
131 Timer::Stop("tree_building");
132
133 DualTreeBoruvka<> dtb(&tree, metric);
134
135 // Run the DTB algorithm.
136 Log::Info << "Calculating minimum spanning tree." << endl;
137 arma::mat results;
138 dtb.ComputeMST(results);
139
140 // Unmap the results.
141 arma::mat unmappedResults(results.n_rows, results.n_cols);
142 for (size_t i = 0; i < results.n_cols; ++i)
143 {
144 const size_t indexA = oldFromNew[size_t(results(0, i))];
145 const size_t indexB = oldFromNew[size_t(results(1, i))];
146
147 if (indexA < indexB)
148 {
149 unmappedResults(0, i) = indexA;
150 unmappedResults(1, i) = indexB;
151 }
152 else
153 {
154 unmappedResults(0, i) = indexB;
155 unmappedResults(1, i) = indexA;
156 }
157
158 unmappedResults(2, i) = results(2, i);
159 }
160
161 if (IO::HasParam("output"))
162 IO::GetParam<arma::mat>("output") = std::move(unmappedResults);
163 }
164 }
165