1 /**
2 * @author Parikshit Ram (pram@cc.gatech.edu)
3 * @file methods/naive_bayes/nbc_main.cpp
4 *
5 * This program runs the Simple Naive Bayes Classifier.
6 *
7 * This classifier does parametric naive bayes classification assuming that the
8 * features are sampled from a Gaussian distribution.
9 *
10 * mlpack is free software; you may redistribute it and/or modify it under the
11 * terms of the 3-clause BSD license. You should have received a copy of the
12 * 3-clause BSD license along with mlpack. If not, see
13 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
14 */
#include <mlpack/prereqs.hpp>
#include <mlpack/core/util/io.hpp>
#include <mlpack/core/data/normalize_labels.hpp>
#include <mlpack/core/util/mlpack_main.hpp>
#include <memory>
19
20 #include "naive_bayes_classifier.hpp"
21
22 using namespace mlpack;
23 using namespace mlpack::naive_bayes;
24 using namespace mlpack::util;
25 using namespace std;
26 using namespace arma;
27
// Program name of this binding.
BINDING_NAME("Parametric Naive Bayes Classifier");

// Short description: a brief summary of the binding; the detailed text is
// given in BINDING_LONG_DESC below.
BINDING_SHORT_DESC(
    "An implementation of the Naive Bayes Classifier, used for classification. "
    "Given labeled data, an NBC model can be trained and saved, or, a "
    "pre-trained model can be used for classification.");
36
37 // Long description.
38 BINDING_LONG_DESC(
39 "This program trains the Naive Bayes classifier on the given labeled "
40 "training set, or loads a model from the given model file, and then may use"
41 " that trained model to classify the points in a given test set."
42 "\n\n"
43 "The training set is specified with the " +
44 PRINT_PARAM_STRING("training") + " parameter. Labels may be either the "
45 "last row of the training set, or alternately the " +
46 PRINT_PARAM_STRING("labels") + " parameter may be specified to pass a "
47 "separate matrix of labels."
48 "\n\n"
49 "If training is not desired, a pre-existing model may be loaded with the " +
50 PRINT_PARAM_STRING("input_model") + " parameter."
51 "\n\n"
52 "\n\n"
53 "The " + PRINT_PARAM_STRING("incremental_variance") + " parameter can be "
54 "used to force the training to use an incremental algorithm for calculating"
55 " variance. This is slower, but can help avoid loss of precision in some "
56 "cases."
57 "\n\n"
58 "If classifying a test set is desired, the test set may be specified with "
59 "the " + PRINT_PARAM_STRING("test") + " parameter, and the classifications"
60 " may be saved with the " + PRINT_PARAM_STRING("predictions") +"predictions"
61 " parameter. If saving the trained model is desired, this may be "
62 "done with the " + PRINT_PARAM_STRING("output_model") + " output "
63 "parameter."
64 "\n\n"
65 "Note: the " + PRINT_PARAM_STRING("output") + " and " +
66 PRINT_PARAM_STRING("output_probs") + " parameters are deprecated and will "
67 "be removed in mlpack 4.0.0. Use " + PRINT_PARAM_STRING("predictions") +
68 " and " + PRINT_PARAM_STRING("probabilities") + " instead.");
69
70 // Example.
71 BINDING_EXAMPLE(
72 "For example, to train a Naive Bayes classifier on the dataset " +
73 PRINT_DATASET("data") + " with labels " + PRINT_DATASET("labels") + " "
74 "and save the model to " + PRINT_MODEL("nbc_model") + ", the following "
75 "command may be used:"
76 "\n\n" +
77 PRINT_CALL("nbc", "training", "data", "labels", "labels", "output_model",
78 "nbc_model") +
79 "\n\n"
80 "Then, to use " + PRINT_MODEL("nbc_model") + " to predict the classes of "
81 "the dataset " + PRINT_DATASET("test_set") + " and save the predicted "
82 "classes to " + PRINT_DATASET("predictions") + ", the following command "
83 "may be used:"
84 "\n\n" +
85 PRINT_CALL("nbc", "input_model", "nbc_model", "test", "test_set", "output",
86 "predictions"));
87
// See also: related mlpack bindings (softmax regression, random forest), the
// Wikipedia article on Naive Bayes classifiers, and the Doxygen page of the
// underlying mlpack::naive_bayes::NaiveBayesClassifier C++ class.
BINDING_SEE_ALSO("@softmax_regression", "#softmax_regression");
BINDING_SEE_ALSO("@random_forest", "#random_forest");
BINDING_SEE_ALSO("Naive Bayes classifier on Wikipedia",
    "https://en.wikipedia.org/wiki/Naive_Bayes_classifier");
BINDING_SEE_ALSO("mlpack::naive_bayes::NaiveBayesClassifier C++ class "
    "documentation", "@doxygen/classmlpack_1_1naive__bayes_1_1"
    "NaiveBayesClassifier.html");
96
97 // A struct for saving the model with mappings.
98 struct NBCModel
99 {
100 //! The model itself.
101 NaiveBayesClassifier<> nbc;
102 //! The mappings for labels.
103 Col<size_t> mappings;
104
105 //! Serialize the model.
106 template<typename Archive>
serializeNBCModel107 void serialize(Archive& ar, const unsigned int /* version */)
108 {
109 ar & BOOST_SERIALIZATION_NVP(nbc);
110 ar & BOOST_SERIALIZATION_NVP(mappings);
111 }
112 };
113
// Model loading/saving.  Both parameters carry an NBCModel (the classifier
// together with its label mappings).
PARAM_MODEL_IN(NBCModel, "input_model", "Input Naive Bayes "
    "model.", "m");
PARAM_MODEL_OUT(NBCModel, "output_model", "File to save trained "
    "Naive Bayes model to.", "M");

// Training parameters.  "labels" and "incremental_variance" only take effect
// together with "training"; when "labels" is not given, the last row of
// "training" is used as the labels.
PARAM_MATRIX_IN("training", "A matrix containing the training set.", "t");
PARAM_UROW_IN("labels", "A file containing labels for the training set.",
    "l");
PARAM_FLAG("incremental_variance", "The variance of each class will be "
    "calculated incrementally.", "I");

// Test parameters.  The outputs below are only filled when "test" is given:
// "output" and "predictions" receive the same predicted labels, and
// "output_probs" and "probabilities" receive the same probability matrix.
PARAM_MATRIX_IN("test", "A matrix containing the test set.", "T");
// The parameter 'output' is deprecated and will be removed in mlpack 4.
PARAM_UROW_OUT("output", "The matrix in which the predicted labels for the"
    " test set will be written (deprecated).", "o");
PARAM_UROW_OUT("predictions", "The matrix in which the predicted labels for the"
    " test set will be written.", "a");
// The parameter 'output_probs' is deprecated and will be removed in mlpack 4.
PARAM_MATRIX_OUT("output_probs", "The matrix in which the predicted probability"
    " of labels for the test set will be written (deprecated).", "");
PARAM_MATRIX_OUT("probabilities", "The matrix in which the predicted"
    " probability of labels for the test set will be written.", "p");
139
mlpackMain()140 static void mlpackMain()
141 {
142 // Check input parameters.
143 RequireOnlyOnePassed({ "training", "input_model" }, true);
144 ReportIgnoredParam({{ "training", false }}, "labels");
145 ReportIgnoredParam({{ "training", false }}, "incremental_variance");
146 RequireAtLeastOnePassed({ "output", "predictions", "output_model",
147 "output_probs", "probabilities" }, false, "no output will be saved");
148 ReportIgnoredParam({{ "test", false }}, "output");
149 ReportIgnoredParam({{ "test", false }}, "predictions");
150 if (IO::HasParam("input_model") && !IO::HasParam("test"))
151 Log::Warn << "No test set given; no task will be performed!" << std::endl;
152
153 // Either we have to train a model, or load a model.
154 NBCModel* model;
155 if (IO::HasParam("training"))
156 {
157 model = new NBCModel();
158 mat trainingData = std::move(IO::GetParam<mat>("training"));
159
160 Row<size_t> labels;
161
162 // Did the user pass in labels?
163 if (IO::HasParam("labels"))
164 {
165 // Load labels.
166 Row<size_t> rawLabels = std::move(IO::GetParam<Row<size_t>>("labels"));
167 data::NormalizeLabels(rawLabels, labels, model->mappings);
168 }
169 else
170 {
171 // Use the last row of the training data as the labels.
172 Log::Info << "Using last dimension of training data as training labels."
173 << endl;
174 data::NormalizeLabels(trainingData.row(trainingData.n_rows - 1), labels,
175 model->mappings);
176 // Remove the label row.
177 trainingData.shed_row(trainingData.n_rows - 1);
178 }
179 const bool incrementalVariance = IO::HasParam("incremental_variance");
180
181 Timer::Start("nbc_training");
182 model->nbc = NaiveBayesClassifier<>(trainingData, labels,
183 model->mappings.n_elem, incrementalVariance);
184 Timer::Stop("nbc_training");
185 }
186 else
187 {
188 // Load the model from file.
189 model = IO::GetParam<NBCModel*>("input_model");
190 }
191
192 // Do we need to do testing?
193 if (IO::HasParam("test"))
194 {
195 mat testingData = std::move(IO::GetParam<mat>("test"));
196
197 if (testingData.n_rows != model->nbc.Means().n_rows)
198 {
199 Log::Fatal << "Test data dimensionality (" << testingData.n_rows << ") "
200 << "must be the same as training data (" << model->nbc.Means().n_rows
201 << ")!" << std::endl;
202 }
203
204 // Time the running of the Naive Bayes Classifier.
205 Row<size_t> predictions;
206 mat probabilities;
207 Timer::Start("nbc_testing");
208 model->nbc.Classify(testingData, predictions, probabilities);
209 Timer::Stop("nbc_testing");
210
211 if (IO::HasParam("output") || IO::HasParam("predictions"))
212 {
213 // Un-normalize labels to prepare output.
214 Row<size_t> rawResults;
215 data::RevertLabels(predictions, model->mappings, rawResults);
216
217 if (IO::HasParam("predictions"))
218 IO::GetParam<Row<size_t>>("predictions") = rawResults;
219 if (IO::HasParam("output"))
220 IO::GetParam<Row<size_t>>("output") = std::move(rawResults);
221 }
222 if (IO::HasParam("output_probs") || IO::HasParam("probabilities"))
223 {
224 if (IO::HasParam("probabilities"))
225 IO::GetParam<mat>("probabilities") = probabilities;
226 if (IO::HasParam("output_probs"))
227 IO::GetParam<mat>("output_probs") = std::move(probabilities);
228 }
229 }
230
231 IO::GetParam<NBCModel*>("output_model") = model;
232 }
233