1 /**
2 * @author Parikshit Ram
3 * @file methods/gmm/gmm_train_main.cpp
4 *
5 * This program trains a mixture of Gaussians on a given data matrix.
6 *
7 * mlpack is free software; you may redistribute it and/or modify it under the
8 * terms of the 3-clause BSD license. You should have received a copy of the
9 * 3-clause BSD license along with mlpack. If not, see
10 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11 */
12 #include <mlpack/prereqs.hpp>
13 #include <mlpack/core/util/io.hpp>
14 #include <mlpack/core/util/mlpack_main.hpp>
15
16 #include "gmm.hpp"
17 #include "diagonal_gmm.hpp"
18 #include "no_constraint.hpp"
19 #include "diagonal_constraint.hpp"
20
21 #include <mlpack/methods/kmeans/refined_start.hpp>
22
23 using namespace mlpack;
24 using namespace mlpack::gmm;
25 using namespace mlpack::util;
26 using namespace mlpack::kmeans;
27 using namespace std;
28
// Binding name, shown in the generated documentation for this tool.
BINDING_NAME("Gaussian Mixture Model (GMM) Training");

// Short description.
BINDING_SHORT_DESC(
    "An implementation of the EM algorithm for training Gaussian mixture "
    "models (GMMs). Given a dataset, this can train a GMM for future use "
    "with other tools.");

// Long description.
BINDING_LONG_DESC(
    "This program takes a parametric estimate of a Gaussian mixture model (GMM)"
    " using the EM algorithm to find the maximum likelihood estimate. The "
    "model may be saved and reused by other mlpack GMM tools."
    "\n\n"
    "The input data to train on must be specified with the " +
    PRINT_PARAM_STRING("input") + " parameter, and the number of Gaussians in "
    "the model must be specified with the " + PRINT_PARAM_STRING("gaussians") +
    " parameter. Optionally, many trials with different random "
    "initializations may be run, and the result with highest log-likelihood on "
    "the training data will be taken. The number of trials to run is specified"
    " with the " + PRINT_PARAM_STRING("trials") + " parameter. By default, "
    "only one trial is run."
    "\n\n"
    "The tolerance for convergence and maximum number of iterations of the EM "
    "algorithm are specified with the " + PRINT_PARAM_STRING("tolerance") +
    " and " + PRINT_PARAM_STRING("max_iterations") + " parameters, "
    "respectively. The GMM may be initialized for training with another model,"
    " specified with the " + PRINT_PARAM_STRING("input_model") + " parameter."
    " Otherwise, the model is initialized by running k-means on the data. The "
    "k-means clustering initialization can be controlled with the " +
    PRINT_PARAM_STRING("kmeans_max_iterations") + ", " +
    PRINT_PARAM_STRING("refined_start") + ", " +
    PRINT_PARAM_STRING("samplings") + ", and " +
    PRINT_PARAM_STRING("percentage") + " parameters. If " +
    PRINT_PARAM_STRING("refined_start") + " is specified, then the "
    "Bradley-Fayyad refined start initialization will be used. This can often "
    "lead to better clustering results."
    "\n\n"
    "The 'diagonal_covariance' flag will cause the learned covariances to be "
    "diagonal matrices. This significantly simplifies the model itself and "
    "causes training to be faster, but restricts the ability to fit more "
    "complex GMMs."
    "\n\n"
    "If GMM training fails with an error indicating that a covariance matrix "
    "could not be inverted, make sure that the " +
    PRINT_PARAM_STRING("no_force_positive") + " parameter is not "
    "specified. Alternately, adding a small amount of Gaussian noise (using "
    "the " + PRINT_PARAM_STRING("noise") + " parameter) to the entire dataset"
    " may help prevent Gaussians with zero variance in a particular dimension, "
    "which is usually the cause of non-invertible covariance matrices."
    "\n\n"
    "The " + PRINT_PARAM_STRING("no_force_positive") + " parameter, if set, "
    "will avoid the checks after each iteration of the EM algorithm which "
    "ensure that the covariance matrices are positive definite. Specifying "
    "the flag can cause faster runtime, but may also cause non-positive "
    "definite covariance matrices, which will cause the program to crash.");

// Example usage, shown in the generated documentation.
BINDING_EXAMPLE(
    "As an example, to train a 6-Gaussian GMM on the data in " +
    PRINT_DATASET("data") + " with a maximum of 100 iterations of EM and 3 "
    "trials, saving the trained GMM to " + PRINT_MODEL("gmm") + ", the "
    "following command can be used:"
    "\n\n" +
    PRINT_CALL("gmm_train", "input", "data", "gaussians", 6, "trials", 3,
        "output_model", "gmm") +
    "\n\n"
    "To re-train that GMM on another set of data " + PRINT_DATASET("data2") +
    ", the following command may be used: "
    "\n\n" +
    PRINT_CALL("gmm_train", "input_model", "gmm", "input", "data2",
        "gaussians", 6, "output_model", "new_gmm"));

// Cross-references to related bindings and external documentation.
BINDING_SEE_ALSO("@gmm_generate", "#gmm_generate");
BINDING_SEE_ALSO("@gmm_probability", "#gmm_probability");
BINDING_SEE_ALSO("Gaussian Mixture Models on Wikipedia",
    "https://en.wikipedia.org/wiki/Mixture_model#Gaussian_mixture_model");
BINDING_SEE_ALSO("mlpack::gmm::GMM class documentation",
    "@doxygen/classmlpack_1_1gmm_1_1GMM.html");

// Parameters for training: the dataset and the model size.
PARAM_MATRIX_IN_REQ("input", "The training data on which the model will be "
    "fit.", "i");
PARAM_INT_IN_REQ("gaussians", "Number of Gaussians in the GMM.", "g");

// Randomness control: RNG seed and number of random restarts.
PARAM_INT_IN("seed", "Random seed. If 0, 'std::time(NULL)' is used.", "s", 0);
PARAM_INT_IN("trials", "Number of trials to perform in training GMM.", "t", 1);

// Parameters for the EM algorithm (convergence and covariance constraints).
PARAM_DOUBLE_IN("tolerance", "Tolerance for convergence of EM.", "T", 1e-10);
PARAM_FLAG("no_force_positive", "Do not force the covariance matrices to be "
    "positive definite.", "P");
PARAM_INT_IN("max_iterations", "Maximum number of iterations of EM algorithm "
    "(passing 0 will run until convergence).", "n", 250);
PARAM_FLAG("diagonal_covariance", "Force the covariance of the Gaussians to "
    "be diagonal. This can accelerate training time significantly.", "d");

// Parameters for dataset modification (noise helps avoid singular
// covariances in zero-variance dimensions).
PARAM_DOUBLE_IN("noise", "Variance of zero-mean Gaussian noise to add to data.",
    "N", 0);

// Parameters for the k-means initialization of the EM algorithm.
PARAM_INT_IN("kmeans_max_iterations", "Maximum number of iterations for the "
    "k-means algorithm (used to initialize EM).", "k", 1000);
PARAM_FLAG("refined_start", "During the initialization, use refined initial "
    "positions for k-means clustering (Bradley and Fayyad, 1998).", "r");
PARAM_INT_IN("samplings", "If using --refined_start, specify the number of "
    "samplings used for initial points.", "S", 100);
PARAM_DOUBLE_IN("percentage", "If using --refined_start, specify the percentage"
    " of the dataset used for each sampling (should be between 0.0 and 1.0).",
    "p", 0.02);

// Parameters for model saving/loading.
PARAM_MODEL_IN(GMM, "input_model", "Initial input GMM model to start training "
    "with.", "m");
PARAM_MODEL_OUT(GMM, "output_model", "Output for trained GMM model.", "M");
147
mlpackMain()148 static void mlpackMain()
149 {
150 // Check parameters and load data.
151 if (IO::GetParam<int>("seed") != 0)
152 math::RandomSeed((size_t) IO::GetParam<int>("seed"));
153 else
154 math::RandomSeed((size_t) std::time(NULL));
155
156 RequireParamValue<int>("gaussians", [](int x) { return x > 0; }, true,
157 "number of Gaussians must be positive");
158 const int gaussians = IO::GetParam<int>("gaussians");
159
160 RequireParamValue<int>("trials", [](int x) { return x > 0; }, true,
161 "trials must be greater than 0");
162
163 ReportIgnoredParam({{ "diagonal_covariance", true }}, "no_force_positive");
164 RequireAtLeastOnePassed({ "output_model" }, false, "no model will be saved");
165
166 RequireParamValue<double>("noise", [](double x) { return x >= 0.0; }, true,
167 "variance of noise must be greater than or equal to 0");
168
169 RequireParamValue<int>("max_iterations", [](int x) { return x >= 0; }, true,
170 "max_iterations must be greater than or equal to 0");
171 RequireParamValue<int>("kmeans_max_iterations", [](int x) { return x >= 0; },
172 true, "kmeans_max_iterations must be greater than or equal to 0");
173
174 arma::mat dataPoints = std::move(IO::GetParam<arma::mat>("input"));
175
176 // Do we need to add noise to the dataset?
177 if (IO::HasParam("noise"))
178 {
179 Timer::Start("noise_addition");
180 const double noise = IO::GetParam<double>("noise");
181 dataPoints += noise * arma::randn(dataPoints.n_rows, dataPoints.n_cols);
182 Log::Info << "Added zero-mean Gaussian noise with variance " << noise
183 << " to dataset." << std::endl;
184 Timer::Stop("noise_addition");
185 }
186
187 // Initialize GMM.
188 GMM* gmm = NULL;
189
190 if (IO::HasParam("input_model"))
191 {
192 gmm = IO::GetParam<GMM*>("input_model");
193
194 if (gmm->Dimensionality() != dataPoints.n_rows)
195 Log::Fatal << "Given input data (with " << PRINT_PARAM_STRING("input")
196 << ") has dimensionality " << dataPoints.n_rows << ", but the initial"
197 << " model (given with " << PRINT_PARAM_STRING("input_model")
198 << " has dimensionality " << gmm->Dimensionality() << "!" << endl;
199 }
200
201 // Gather parameters for EMFit object.
202 const size_t maxIterations = (size_t) IO::GetParam<int>("max_iterations");
203 const double tolerance = IO::GetParam<double>("tolerance");
204 const bool forcePositive = !IO::HasParam("no_force_positive");
205 const bool diagonalCovariance = IO::HasParam("diagonal_covariance");
206 const size_t kmeansMaxIterations =
207 (size_t) IO::GetParam<int>("kmeans_max_iterations");
208
209 // This gets a bit weird because we need different types depending on whether
210 // --refined_start is specified.
211 double likelihood;
212 if (IO::HasParam("refined_start"))
213 {
214 RequireParamValue<int>("samplings", [](int x) { return x > 0; }, true,
215 "number of samplings must be positive");
216 RequireParamValue<double>("percentage", [](double x) {
217 return x > 0.0 && x <= 1.0; }, true, "percentage to sample must be "
218 "be greater than 0.0 and less than or equal to 1.0");
219
220 // Initialize the GMM if needed. (We didn't do this earlier, because
221 // RequireParamValue() would leak the memory if the check failed.)
222 if (!IO::HasParam("input_model"))
223 gmm = new GMM(size_t(gaussians), dataPoints.n_rows);
224
225 const int samplings = IO::GetParam<int>("samplings");
226 const double percentage = IO::GetParam<double>("percentage");
227
228 typedef KMeans<metric::SquaredEuclideanDistance, RefinedStart> KMeansType;
229
230 KMeansType k(kmeansMaxIterations, metric::SquaredEuclideanDistance(),
231 RefinedStart(samplings, percentage));
232
233 // Depending on the value of forcePositive and diagonalCovariance, we have
234 // to use different types.
235 if (diagonalCovariance)
236 {
237 // Convert GMMs into DiagonalGMMs.
238 DiagonalGMM dgmm(gmm->Gaussians(), gmm->Dimensionality());
239 for (size_t i = 0; i < size_t(gaussians); ++i)
240 {
241 dgmm.Component(i).Mean() = gmm->Component(i).Mean();
242 dgmm.Component(i).Covariance(
243 std::move(arma::diagvec(gmm->Component(i).Covariance())));
244 }
245 dgmm.Weights() = gmm->Weights();
246
247 // Compute the parameters of the model using the EM algorithm.
248 Timer::Start("em");
249 EMFit<KMeansType, PositiveDefiniteConstraint,
250 distribution::DiagonalGaussianDistribution> em(maxIterations,
251 tolerance, k);
252
253 likelihood = dgmm.Train(dataPoints, IO::GetParam<int>("trials"), false,
254 em);
255 Timer::Stop("em");
256
257 // Convert DiagonalGMMs into GMMs.
258 for (size_t i = 0; i < size_t(gaussians); ++i)
259 {
260 gmm->Component(i).Mean() = dgmm.Component(i).Mean();
261 gmm->Component(i).Covariance(
262 arma::diagmat(dgmm.Component(i).Covariance()));
263 }
264 gmm->Weights() = dgmm.Weights();
265 }
266 else if (forcePositive)
267 {
268 // Compute the parameters of the model using the EM algorithm.
269 Timer::Start("em");
270 EMFit<KMeansType> em(maxIterations, tolerance, k);
271 likelihood = gmm->Train(dataPoints, IO::GetParam<int>("trials"), false,
272 em);
273 Timer::Stop("em");
274 }
275 else
276 {
277 // Compute the parameters of the model using the EM algorithm.
278 Timer::Start("em");
279 EMFit<KMeansType, NoConstraint> em(maxIterations, tolerance, k);
280 likelihood = gmm->Train(dataPoints, IO::GetParam<int>("trials"), false,
281 em);
282 Timer::Stop("em");
283 }
284 }
285 else
286 {
287 // Initialize the GMM if needed.
288 if (!IO::HasParam("input_model"))
289 gmm = new GMM(size_t(gaussians), dataPoints.n_rows);
290
291 // Depending on the value of forcePositive and diagonalCovariance, we have
292 // to use different types.
293 if (diagonalCovariance)
294 {
295 // Convert GMMs into DiagonalGMMs.
296 DiagonalGMM dgmm(gmm->Gaussians(), gmm->Dimensionality());
297 for (size_t i = 0; i < size_t(gaussians); ++i)
298 {
299 dgmm.Component(i).Mean() = gmm->Component(i).Mean();
300 dgmm.Component(i).Covariance(
301 std::move(arma::diagvec(gmm->Component(i).Covariance())));
302 }
303 dgmm.Weights() = gmm->Weights();
304
305 // Compute the parameters of the model using the EM algorithm.
306 Timer::Start("em");
307 EMFit<KMeans<>, PositiveDefiniteConstraint,
308 distribution::DiagonalGaussianDistribution> em(maxIterations,
309 tolerance, KMeans<>(kmeansMaxIterations));
310
311 likelihood = dgmm.Train(dataPoints, IO::GetParam<int>("trials"), false,
312 em);
313 Timer::Stop("em");
314
315 // Convert DiagonalGMMs into GMMs.
316 for (size_t i = 0; i < size_t(gaussians); ++i)
317 {
318 gmm->Component(i).Mean() = dgmm.Component(i).Mean();
319 gmm->Component(i).Covariance(
320 arma::diagmat(dgmm.Component(i).Covariance()));
321 }
322 gmm->Weights() = dgmm.Weights();
323 }
324 else if (forcePositive)
325 {
326 // Compute the parameters of the model using the EM algorithm.
327 Timer::Start("em");
328 EMFit<> em(maxIterations, tolerance, KMeans<>(kmeansMaxIterations));
329 likelihood = gmm->Train(dataPoints, IO::GetParam<int>("trials"), false,
330 em);
331 Timer::Stop("em");
332 }
333 else
334 {
335 // Compute the parameters of the model using the EM algorithm.
336 Timer::Start("em");
337 KMeans<> k(kmeansMaxIterations);
338 EMFit<KMeans<>, NoConstraint> em(maxIterations, tolerance, k);
339 likelihood = gmm->Train(dataPoints, IO::GetParam<int>("trials"), false,
340 em);
341 Timer::Stop("em");
342 }
343 }
344
345 Log::Info << "Log-likelihood of estimate: " << likelihood << "." << endl;
346
347 IO::GetParam<GMM*>("output_model") = gmm;
348 }
349