1 /**
2  * @author Parikshit Ram (pram@cc.gatech.edu)
3  * @file methods/naive_bayes/nbc_main.cpp
4  *
 * This program runs the simple Naive Bayes classifier.
 *
 * This classifier performs parametric naive Bayes classification, assuming
 * that each feature is sampled from a Gaussian distribution.
9  *
10  * mlpack is free software; you may redistribute it and/or modify it under the
11  * terms of the 3-clause BSD license.  You should have received a copy of the
12  * 3-clause BSD license along with mlpack.  If not, see
13  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
14  */
#include <mlpack/prereqs.hpp>
#include <mlpack/core/util/io.hpp>
#include <mlpack/core/data/normalize_labels.hpp>
#include <mlpack/core/util/mlpack_main.hpp>

#include <memory>

#include "naive_bayes_classifier.hpp"

using namespace mlpack;
using namespace mlpack::naive_bayes;
using namespace mlpack::util;
using namespace std;
using namespace arma;
28 // Program Name.
29 BINDING_NAME("Parametric Naive Bayes Classifier");
31 // Short description.
33     "An implementation of the Naive Bayes Classifier, used for classification. "
34     "Given labeled data, an NBC model can be trained and saved, or, a "
35     "pre-trained model can be used for classification.");
37 // Long description.
39     "This program trains the Naive Bayes classifier on the given labeled "
40     "training set, or loads a model from the given model file, and then may use"
41     " that trained model to classify the points in a given test set."
42     "\n\n"
43     "The training set is specified with the " +
44     PRINT_PARAM_STRING("training") + " parameter.  Labels may be either the "
45     "last row of the training set, or alternately the " +
46     PRINT_PARAM_STRING("labels") + " parameter may be specified to pass a "
47     "separate matrix of labels."
48     "\n\n"
49     "If training is not desired, a pre-existing model may be loaded with the " +
50     PRINT_PARAM_STRING("input_model") + " parameter."
51     "\n\n"
52     "\n\n"
53     "The " + PRINT_PARAM_STRING("incremental_variance") + " parameter can be "
54     "used to force the training to use an incremental algorithm for calculating"
55     " variance.  This is slower, but can help avoid loss of precision in some "
56     "cases."
57     "\n\n"
58     "If classifying a test set is desired, the test set may be specified with "
59     "the " + PRINT_PARAM_STRING("test") + " parameter, and the classifications"
60     " may be saved with the " + PRINT_PARAM_STRING("predictions") +"predictions"
61     "  parameter.  If saving the trained model is desired, this may be "
62     "done with the " + PRINT_PARAM_STRING("output_model") + " output "
63     "parameter."
64     "\n\n"
65     "Note: the " + PRINT_PARAM_STRING("output") + " and " +
66     PRINT_PARAM_STRING("output_probs") + " parameters are deprecated and will "
67     "be removed in mlpack 4.0.0.  Use " + PRINT_PARAM_STRING("predictions") +
68     " and " + PRINT_PARAM_STRING("probabilities") + " instead.");
70 // Example.
72     "For example, to train a Naive Bayes classifier on the dataset " +
73     PRINT_DATASET("data") + " with labels " + PRINT_DATASET("labels") + " "
74     "and save the model to " + PRINT_MODEL("nbc_model") + ", the following "
75     "command may be used:"
76     "\n\n" +
77     PRINT_CALL("nbc", "training", "data", "labels", "labels", "output_model",
78         "nbc_model") +
79     "\n\n"
80     "Then, to use " + PRINT_MODEL("nbc_model") + " to predict the classes of "
81     "the dataset " + PRINT_DATASET("test_set") + " and save the predicted "
82     "classes to " + PRINT_DATASET("predictions") + ", the following command "
83     "may be used:"
84     "\n\n" +
85     PRINT_CALL("nbc", "input_model", "nbc_model", "test", "test_set", "output",
86         "predictions"));
88 // See also...
89 BINDING_SEE_ALSO("@softmax_regression", "#softmax_regression");
90 BINDING_SEE_ALSO("@random_forest", "#random_forest");
91 BINDING_SEE_ALSO("Naive Bayes classifier on Wikipedia",
92         "https://en.wikipedia.org/wiki/Naive_Bayes_classifier");
93 BINDING_SEE_ALSO("mlpack::naive_bayes::NaiveBayesClassifier C++ class "
94         "documentation", "@doxygen/classmlpack_1_1naive__bayes_1_1"
95         "NaiveBayesClassifier.html");
97 // A struct for saving the model with mappings.
98 struct NBCModel
99 {
100   //! The model itself.
101   NaiveBayesClassifier<> nbc;
102   //! The mappings for labels.
103   Col<size_t> mappings;
105   //! Serialize the model.
106   template<typename Archive>
serializeNBCModel107   void serialize(Archive& ar, const unsigned int /* version */)
108   {
109     ar & BOOST_SERIALIZATION_NVP(nbc);
110     ar & BOOST_SERIALIZATION_NVP(mappings);
111   }
112 };
114 // Model loading/saving.
115 PARAM_MODEL_IN(NBCModel, "input_model", "Input Naive Bayes "
116     "model.", "m");
117 PARAM_MODEL_OUT(NBCModel, "output_model", "File to save trained "
118     "Naive Bayes model to.", "M");
120 // Training parameters.
121 PARAM_MATRIX_IN("training", "A matrix containing the training set.", "t");
122 PARAM_UROW_IN("labels", "A file containing labels for the training set.",
123     "l");
124 PARAM_FLAG("incremental_variance", "The variance of each class will be "
125     "calculated incrementally.", "I");
127 // Test parameters.
128 PARAM_MATRIX_IN("test", "A matrix containing the test set.", "T");
129 // The parameter 'output' is deprecated and will be removed in mlpack 4.
130 PARAM_UROW_OUT("output", "The matrix in which the predicted labels for the"
131     " test set will be written (deprecated).", "o");
132 PARAM_UROW_OUT("predictions", "The matrix in which the predicted labels for the"
133     " test set will be written.", "a");
134 // The parameter 'output_probs' is deprecated and can be removed in mlpack 4.
135 PARAM_MATRIX_OUT("output_probs", "The matrix in which the predicted probability"
136     " of labels for the test set will be written (deprecated).", "");
137 PARAM_MATRIX_OUT("probabilities", "The matrix in which the predicted"
138     " probability of labels for the test set will be written.", "p");
mlpackMain()140 static void mlpackMain()
141 {
142   // Check input parameters.
143   RequireOnlyOnePassed({ "training", "input_model" }, true);
144   ReportIgnoredParam({{ "training", false }}, "labels");
145   ReportIgnoredParam({{ "training", false }}, "incremental_variance");
146   RequireAtLeastOnePassed({ "output", "predictions", "output_model",
147       "output_probs", "probabilities" }, false, "no output will be saved");
148   ReportIgnoredParam({{ "test", false }}, "output");
149   ReportIgnoredParam({{ "test", false }}, "predictions");
150   if (IO::HasParam("input_model") && !IO::HasParam("test"))
151     Log::Warn << "No test set given; no task will be performed!" << std::endl;
153   // Either we have to train a model, or load a model.
154   NBCModel* model;
155   if (IO::HasParam("training"))
156   {
157     model = new NBCModel();
158     mat trainingData = std::move(IO::GetParam<mat>("training"));
160     Row<size_t> labels;
162     // Did the user pass in labels?
163     if (IO::HasParam("labels"))
164     {
165       // Load labels.
166       Row<size_t> rawLabels = std::move(IO::GetParam<Row<size_t>>("labels"));
167       data::NormalizeLabels(rawLabels, labels, model->mappings);
168     }
169     else
170     {
171       // Use the last row of the training data as the labels.
172       Log::Info << "Using last dimension of training data as training labels."
173           << endl;
174       data::NormalizeLabels(trainingData.row(trainingData.n_rows - 1), labels,
175           model->mappings);
176       // Remove the label row.
177       trainingData.shed_row(trainingData.n_rows - 1);
178     }
179     const bool incrementalVariance = IO::HasParam("incremental_variance");
181     Timer::Start("nbc_training");
182     model->nbc = NaiveBayesClassifier<>(trainingData, labels,
183         model->mappings.n_elem, incrementalVariance);
184     Timer::Stop("nbc_training");
185   }
186   else
187   {
188     // Load the model from file.
189     model = IO::GetParam<NBCModel*>("input_model");
190   }
192   // Do we need to do testing?
193   if (IO::HasParam("test"))
194   {
195     mat testingData = std::move(IO::GetParam<mat>("test"));
197     if (testingData.n_rows != model->nbc.Means().n_rows)
198     {
199       Log::Fatal << "Test data dimensionality (" << testingData.n_rows << ") "
200           << "must be the same as training data (" << model->nbc.Means().n_rows
201           << ")!" << std::endl;
202     }
204     // Time the running of the Naive Bayes Classifier.
205     Row<size_t> predictions;
206     mat probabilities;
207     Timer::Start("nbc_testing");
208     model->nbc.Classify(testingData, predictions, probabilities);
209     Timer::Stop("nbc_testing");
211     if (IO::HasParam("output") || IO::HasParam("predictions"))
212     {
213       // Un-normalize labels to prepare output.
214       Row<size_t> rawResults;
215       data::RevertLabels(predictions, model->mappings, rawResults);
217       if (IO::HasParam("predictions"))
218         IO::GetParam<Row<size_t>>("predictions") = rawResults;
219       if (IO::HasParam("output"))
220         IO::GetParam<Row<size_t>>("output") = std::move(rawResults);
221     }
222     if (IO::HasParam("output_probs") || IO::HasParam("probabilities"))
223     {
224       if (IO::HasParam("probabilities"))
225         IO::GetParam<mat>("probabilities") = probabilities;
226       if (IO::HasParam("output_probs"))
227         IO::GetParam<mat>("output_probs") = std::move(probabilities);
228     }
229   }
231   IO::GetParam<NBCModel*>("output_model") = model;
232 }