1 /**
2 * @file methods/dbscan/dbscan_impl.hpp
3 * @author Ryan Curtin
4 *
5 * Implementation of DBSCAN.
6 *
7 * mlpack is free software; you may redistribute it and/or modify it under the
8 * terms of the 3-clause BSD license. You should have received a copy of the
9 * 3-clause BSD license along with mlpack. If not, see
10 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11 */
12 #ifndef MLPACK_METHODS_DBSCAN_DBSCAN_IMPL_HPP
13 #define MLPACK_METHODS_DBSCAN_DBSCAN_IMPL_HPP
14
#include "dbscan.hpp"

#include <utility>
16
17 namespace mlpack {
18 namespace dbscan {
19
20 /**
21 * Construct the DBSCAN object with the given parameters.
22 */
23 template<typename RangeSearchType, typename PointSelectionPolicy>
DBSCAN(const double epsilon,const size_t minPoints,const bool batchMode,RangeSearchType rangeSearch,PointSelectionPolicy pointSelector)24 DBSCAN<RangeSearchType, PointSelectionPolicy>::DBSCAN(
25 const double epsilon,
26 const size_t minPoints,
27 const bool batchMode,
28 RangeSearchType rangeSearch,
29 PointSelectionPolicy pointSelector) :
30 epsilon(epsilon),
31 minPoints(minPoints),
32 batchMode(batchMode),
33 rangeSearch(rangeSearch),
34 pointSelector(pointSelector)
35 {
36 // Nothing to do.
37 }
38
39 /**
40 * Performs DBSCAN clustering on the data, returning number of clusters
41 * and also the centroid of each cluster.
42 */
43 template<typename RangeSearchType, typename PointSelectionPolicy>
44 template<typename MatType>
Cluster(const MatType & data,arma::mat & centroids)45 size_t DBSCAN<RangeSearchType, PointSelectionPolicy>::Cluster(
46 const MatType& data,
47 arma::mat& centroids)
48 {
49 // These assignments will be thrown away, but there is no way to avoid
50 // calculating them.
51 arma::Row<size_t> assignments(data.n_cols);
52 assignments.fill(SIZE_MAX);
53
54 return Cluster(data, assignments, centroids);
55 }
56
57 /**
58 * Performs DBSCAN clustering on the data, returning number of clusters,
59 * the centroid of each cluster and also the list of cluster assignments.
60 */
61 template<typename RangeSearchType, typename PointSelectionPolicy>
62 template<typename MatType>
Cluster(const MatType & data,arma::Row<size_t> & assignments,arma::mat & centroids)63 size_t DBSCAN<RangeSearchType, PointSelectionPolicy>::Cluster(
64 const MatType& data,
65 arma::Row<size_t>& assignments,
66 arma::mat& centroids)
67 {
68 const size_t numClusters = Cluster(data, assignments);
69
70 // Now calculate the centroids.
71 centroids.zeros(data.n_rows, numClusters);
72
73 // Calculate number of points in each cluster.
74 arma::Row<size_t> counts;
75 counts.zeros(numClusters);
76 for (size_t i = 0; i < data.n_cols; ++i)
77 {
78 if (assignments[i] != SIZE_MAX)
79 {
80 centroids.col(assignments[i]) += data.col(i);
81 ++counts[assignments[i]];
82 }
83 }
84
85 // We should be guaranteed that the number of clusters is always greater than
86 // zero.
87 for (size_t i = 0; i < numClusters; ++i)
88 centroids.col(i) /= counts[i];
89
90 return numClusters;
91 }
92
93 /**
94 * Performs DBSCAN clustering on the data, returning the number of clusters and
95 * also the list of cluster assignments.
96 */
97 template<typename RangeSearchType, typename PointSelectionPolicy>
98 template<typename MatType>
Cluster(const MatType & data,arma::Row<size_t> & assignments)99 size_t DBSCAN<RangeSearchType, PointSelectionPolicy>::Cluster(
100 const MatType& data,
101 arma::Row<size_t>& assignments)
102 {
103 // Initialize the UnionFind object.
104 emst::UnionFind uf(data.n_cols);
105 rangeSearch.Train(data);
106
107 if (batchMode)
108 BatchCluster(data, uf);
109 else
110 PointwiseCluster(data, uf);
111
112 // Now set assignments.
113 assignments.set_size(data.n_cols);
114 for (size_t i = 0; i < data.n_cols; ++i)
115 assignments[i] = uf.Find(i);
116
117 // Get a count of all clusters.
118 const size_t numClusters = arma::max(assignments) + 1;
119 arma::Col<size_t> counts(numClusters, arma::fill::zeros);
120 for (size_t i = 0; i < assignments.n_elem; ++i)
121 counts[assignments[i]]++;
122
123 // Now assign clusters to new indices.
124 size_t currentCluster = 0;
125 arma::Col<size_t> newAssignments(numClusters);
126 for (size_t i = 0; i < counts.n_elem; ++i)
127 {
128 if (counts[i] >= minPoints)
129 newAssignments[i] = currentCluster++;
130 else
131 newAssignments[i] = SIZE_MAX;
132 }
133
134 // Now reassign.
135 for (size_t i = 0; i < assignments.n_elem; ++i)
136 assignments[i] = newAssignments[assignments[i]];
137
138 Log::Info << currentCluster << " clusters found." << std::endl;
139
140 return currentCluster;
141 }
142
143 /**
144 * Performs DBSCAN clustering on the data, returning the number of clusters and
145 * also the list of cluster assignments. This searches each point iteratively,
146 * and can save on RAM usage. It may be slower than the batch search with a
147 * dual-tree algorithm.
148 */
149 template<typename RangeSearchType, typename PointSelectionPolicy>
150 template<typename MatType>
PointwiseCluster(const MatType & data,emst::UnionFind & uf)151 void DBSCAN<RangeSearchType, PointSelectionPolicy>::PointwiseCluster(
152 const MatType& data,
153 emst::UnionFind& uf)
154 {
155 std::vector<std::vector<size_t>> neighbors;
156 std::vector<std::vector<double>> distances;
157
158 for (size_t i = 0; i < data.n_cols; ++i)
159 {
160 if (i % 10000 == 0 && i > 0)
161 Log::Info << "DBSCAN clustering on point " << i << "..." << std::endl;
162
163 // Do the range search for only this point.
164 rangeSearch.Search(data.col(i), math::Range(0.0, epsilon), neighbors,
165 distances);
166
167 // Union to all neighbors.
168 for (size_t j = 0; j < neighbors[0].size(); ++j)
169 uf.Union(i, neighbors[0][j]);
170 }
171 }
172
173 /**
174 * Performs DBSCAN clustering on the data, returning number of clusters
175 * and also the list of cluster assignments. This can perform search in batch,
176 * naive search).
177 */
178 template<typename RangeSearchType, typename PointSelectionPolicy>
179 template<typename MatType>
BatchCluster(const MatType & data,emst::UnionFind & uf)180 void DBSCAN<RangeSearchType, PointSelectionPolicy>::BatchCluster(
181 const MatType& data,
182 emst::UnionFind& uf)
183 {
184 // For each point, find the points in epsilon-nighborhood and their distances.
185 std::vector<std::vector<size_t>> neighbors;
186 std::vector<std::vector<double>> distances;
187 Log::Info << "Performing range search." << std::endl;
188 rangeSearch.Train(data);
189 rangeSearch.Search(data, math::Range(0.0, epsilon), neighbors, distances);
190 Log::Info << "Range search complete." << std::endl;
191
192 // Now loop over all points.
193 for (size_t i = 0; i < data.n_cols; ++i)
194 {
195 // Get the next index.
196 const size_t index = pointSelector.Select(i, data);
197 for (size_t j = 0; j < neighbors[index].size(); ++j)
198 uf.Union(index, neighbors[index][j]);
199 }
200 }
201
202 } // namespace dbscan
203 } // namespace mlpack
204
205 #endif
206