1 /**
2  * @file methods/emst/emst_main.cpp
3  * @author Bill March (march@gatech.edu)
4  *
5  * Calls the DualTreeBoruvka algorithm from dtb.hpp.
6  * Can optionally call naive Boruvka's method.
7  *
8  * For algorithm details, see:
9  *
10  * @code
11  * @inproceedings{
12  *   author = {March, W.B., Ram, P., and Gray, A.G.},
 *   title = {{Fast Euclidean Minimum Spanning Tree: Algorithm, Analysis,
 *      and Applications.}},
 *   booktitle = {Proceedings of the 16th ACM SIGKDD International Conference
 *      on Knowledge Discovery and Data Mining},
17  *   series = {KDD 2010},
18  *   year = {2010}
19  * }
20  * @endcode
21  *
22  * mlpack is free software; you may redistribute it and/or modify it under the
23  * terms of the 3-clause BSD license.  You should have received a copy of the
24  * 3-clause BSD license along with mlpack.  If not, see
25  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
26  */
27 #include <mlpack/prereqs.hpp>
28 #include <mlpack/core/util/io.hpp>
29 #include <mlpack/core/util/mlpack_main.hpp>
30 
31 #include "dtb.hpp"
32 
// Program name: the human-readable title given to this binding.
BINDING_NAME("Fast Euclidean Minimum Spanning Tree");

// Short description: a one-sentence summary stating that the binding computes
// the Euclidean minimum spanning tree with the Dual-Tree Boruvka algorithm.
BINDING_SHORT_DESC(
    "An implementation of the Dual-Tree Boruvka algorithm for computing the "
    "Euclidean minimum spanning tree of a set of input points.");
40 
// Long description.  Three paragraphs (separated by "\n\n"): what the program
// computes, how the "input" and "output" parameters are used, and how the
// "leaf_size" and "naive" parameters affect the computation.
//
// Adjacent string literals are concatenated by the compiler; the explicit `+`
// is only needed around the PRINT_PARAM_STRING() calls, which presumably
// render a parameter name in the style of the target binding language
// (NOTE(review): confirm against the macro definition).
BINDING_LONG_DESC(
    "This program can compute the Euclidean minimum spanning tree of a set of "
    "input points using the dual-tree Boruvka algorithm."
    "\n\n"
    "The set to calculate the minimum spanning tree of is specified with the " +
    PRINT_PARAM_STRING("input") + " parameter, and the output may be saved with"
    " the " + PRINT_PARAM_STRING("output") + " output parameter."
    "\n\n"
    "The " + PRINT_PARAM_STRING("leaf_size") + " parameter controls the leaf "
    "size of the kd-tree that is used to calculate the minimum spanning tree, "
    "and if the " + PRINT_PARAM_STRING("naive") + " option is given, then "
    "brute-force search is used (this is typically much slower in low "
    "dimensions).  The leaf size does not affect the results, but it may have "
    "some effect on the runtime of the algorithm.");
56 
// Example.  Shows a sample invocation (input dataset "data", leaf size 20,
// result stored as "spanning_tree") followed by a description of the output
// format: one edge per entry, holding the lesser point index, the greater
// point index, and the distance between the two points.
//
// Note that the text describes the matrix as the user sees it ("each row
// indicates an edge").  In mlpackMain() in this file the arma::mat is filled
// with one edge per *column*, using rows 0, 1 and 2 for the three values.
//
// PRINT_DATASET() and PRINT_CALL() presumably format dataset names and the
// example call for the target binding language (NOTE(review): confirm against
// the macro definitions).
BINDING_EXAMPLE(
    "For example, the minimum spanning tree of the input dataset " +
    PRINT_DATASET("data") + " can be calculated with a leaf size of 20 and "
    "stored as " + PRINT_DATASET("spanning_tree") + " using the following "
    "command:"
    "\n\n" +
    PRINT_CALL("emst", "input", "data", "leaf_size", 20, "output",
        "spanning_tree") +
    "\n\n"
    "The output matrix is a three-dimensional matrix, where each row indicates "
    "an edge.  The first dimension corresponds to the lesser index of the edge;"
    " the second dimension corresponds to the greater index of the edge; and "
    "the third column corresponds to the distance between the two points.");
71 
// See also: related reading for this binding.  Each entry is a
// (description, link) pair.  Links starting with "@doxygen/" are presumably
// expanded to the location of the generated Doxygen documentation
// (NOTE(review): confirm how the "@doxygen" prefix is resolved).
BINDING_SEE_ALSO("EMST Tutorial", "@doxygen/emst_tutorial.html");
BINDING_SEE_ALSO("Minimum spanning tree on Wikipedia",
        "https://en.wikipedia.org/wiki/Minimum_spanning_tree");
BINDING_SEE_ALSO("Fast Euclidean Minimum Spanning Tree: Algorithm, Analysis,"
        " and Applications (pdf)", "http://www.mlpack.org/papers/emst.pdf");
BINDING_SEE_ALSO("mlpack::emst::DualTreeBoruvka class documentation",
        "@doxygen/classmlpack_1_1emst_1_1DualTreeBoruvka.html");
80 
// Parameters of the binding.  Each macro takes the parameter name, its help
// text, and a one-letter string that is presumably the short command-line
// alias (NOTE(review): confirm against the PARAM_* macro definitions).
//
//  - "input":     required input matrix; read by mlpackMain() as the dataset.
//  - "output":    output matrix; mlpackMain() stores the edge list here when
//                 the parameter was passed.
//  - "naive":     flag; when set, mlpackMain() runs DualTreeBoruvka in its
//                 naive mode instead of building a kd-tree.
//  - "leaf_size": integer; the trailing 1 is presumably the default value.
//                 mlpackMain() rejects values that are not greater than 0.
PARAM_MATRIX_IN_REQ("input", "Input data matrix.", "i");
PARAM_MATRIX_OUT("output", "Output data.  Stored as an edge list.", "o");
PARAM_FLAG("naive", "Compute the MST using O(n^2) naive algorithm.", "n");
PARAM_INT_IN("leaf_size", "Leaf size in the kd-tree.  One-element leaves give "
    "the empirically best performance, but at the cost of greater memory "
    "requirements.", "l", 1);
87 
88 using namespace mlpack;
89 using namespace mlpack::emst;
90 using namespace mlpack::tree;
91 using namespace mlpack::metric;
92 using namespace mlpack::util;
93 using namespace std;
94 
mlpackMain()95 static void mlpackMain()
96 {
97   RequireAtLeastOnePassed({ "output" }, false, "no output will be saved");
98 
99   arma::mat dataPoints = std::move(IO::GetParam<arma::mat>("input"));
100 
101   // Do naive computation if necessary.
102   if (IO::GetParam<bool>("naive"))
103   {
104     Log::Info << "Running naive algorithm." << endl;
105 
106     DualTreeBoruvka<> naive(dataPoints, true);
107 
108     arma::mat naiveResults;
109     naive.ComputeMST(naiveResults);
110 
111     if (IO::HasParam("output"))
112       IO::GetParam<arma::mat>("output") = std::move(naiveResults);
113   }
114   else
115   {
116     Log::Info << "Building tree.\n";
117 
118     // Check that the leaf size is reasonable.
119     RequireParamValue<int>("leaf_size", [](int x) { return x > 0; }, true,
120         "leaf size must be greater than or equal to 1");
121 
122     // Initialize the tree and get ready to compute the MST.  Compute the tree
123     // by hand.
124     const size_t leafSize = (size_t) IO::GetParam<int>("leaf_size");
125 
126     Timer::Start("tree_building");
127     std::vector<size_t> oldFromNew;
128     KDTree<EuclideanDistance, DTBStat, arma::mat> tree(dataPoints, oldFromNew,
129         leafSize);
130     metric::LMetric<2, true> metric;
131     Timer::Stop("tree_building");
132 
133     DualTreeBoruvka<> dtb(&tree, metric);
134 
135     // Run the DTB algorithm.
136     Log::Info << "Calculating minimum spanning tree." << endl;
137     arma::mat results;
138     dtb.ComputeMST(results);
139 
140     // Unmap the results.
141     arma::mat unmappedResults(results.n_rows, results.n_cols);
142     for (size_t i = 0; i < results.n_cols; ++i)
143     {
144       const size_t indexA = oldFromNew[size_t(results(0, i))];
145       const size_t indexB = oldFromNew[size_t(results(1, i))];
146 
147       if (indexA < indexB)
148       {
149         unmappedResults(0, i) = indexA;
150         unmappedResults(1, i) = indexB;
151       }
152       else
153       {
154         unmappedResults(0, i) = indexB;
155         unmappedResults(1, i) = indexA;
156       }
157 
158       unmappedResults(2, i) = results(2, i);
159     }
160 
161     if (IO::HasParam("output"))
162       IO::GetParam<arma::mat>("output") = std::move(unmappedResults);
163   }
164 }
165