/**
 * @file methods/lars/lars_main.cpp
 * @author Nishant Mehta
 *
 * Executable for LARS.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license. You should have received a copy of the
 * 3-clause BSD license along with mlpack. If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#include <mlpack/prereqs.hpp>
#include <mlpack/core/util/io.hpp>
#include <mlpack/core/util/mlpack_main.hpp>

#include "lars.hpp"

using namespace arma;
using namespace std;
using namespace mlpack;
using namespace mlpack::regression;
using namespace mlpack::util;

// Program Name.
BINDING_NAME("LARS");

// Short description.
BINDING_SHORT_DESC(
    "An implementation of Least Angle Regression (Stagewise/laSso), also known"
    " as LARS. This can train a LARS/LASSO/Elastic Net model and use that "
    "model or a pre-trained model to output regression predictions for a test "
    "set.");

// Long description.
BINDING_LONG_DESC(
    "An implementation of LARS: Least Angle Regression (Stagewise/laSso). "
    "This is a stage-wise homotopy-based algorithm for L1-regularized linear "
    "regression (LASSO) and L1+L2-regularized linear regression (Elastic Net)."
    "\n\n"
    "This program is able to train a LARS/LASSO/Elastic Net model or load a "
    "model from file, output regression predictions for a test set, and save "
    "the trained model to a file. The LARS algorithm is described in more "
    "detail below:"
    "\n\n"
    "Let X be a matrix where each row is a point and each column is a "
    "dimension, and let y be a vector of targets."
    "\n\n"
    "The Elastic Net problem is to solve"
    "\n\n"
    "  min_beta 0.5 || X * beta - y ||_2^2 + lambda_1 ||beta||_1 +\n"
    "      0.5 lambda_2 ||beta||_2^2"
    "\n\n"
    "If lambda1 > 0 and lambda2 = 0, the problem is the LASSO.\n"
    "If lambda1 > 0 and lambda2 > 0, the problem is the Elastic Net.\n"
    "If lambda1 = 0 and lambda2 > 0, the problem is ridge regression.\n"
    "If lambda1 = 0 and lambda2 = 0, the problem is unregularized linear "
    "regression."
    "\n\n"
    "For efficiency reasons, it is not recommended to use this algorithm with "
    + PRINT_PARAM_STRING("lambda1") + " = 0. In that case, use the "
    "'linear_regression' program, which implements both unregularized linear "
    "regression and ridge regression."
    "\n\n"
    "To train a LARS/LASSO/Elastic Net model, the " +
    PRINT_PARAM_STRING("input") + " and " + PRINT_PARAM_STRING("responses") +
    " parameters must be given. The " + PRINT_PARAM_STRING("lambda1") +
    ", " + PRINT_PARAM_STRING("lambda2") + ", and " +
    PRINT_PARAM_STRING("use_cholesky") + " parameters control the training "
    "options. A trained model can be saved with the " +
    PRINT_PARAM_STRING("output_model") + " output parameter. If no training "
    "is desired at all, a model can be passed via the " +
    PRINT_PARAM_STRING("input_model") + " parameter."
    "\n\n"
    "The program can also provide predictions for test data using either the "
    "trained model or the given input model. Test points can be specified "
    "with the " + PRINT_PARAM_STRING("test") + " parameter. Predicted "
    "responses to the test points can be saved with the " +
    PRINT_PARAM_STRING("output_predictions") + " output parameter.");

// Example.
BINDING_EXAMPLE(
    "For example, the following command trains a model on the data " +
    PRINT_DATASET("data") + " and responses " + PRINT_DATASET("responses") +
    " with lambda1 set to 0.4 and lambda2 set to 0 (so, LASSO is being "
    "solved), and then the model is saved to " + PRINT_MODEL("lasso_model") +
    ":"
    "\n\n" +
    PRINT_CALL("lars", "input", "data", "responses", "responses", "lambda1",
        0.4, "lambda2", 0.0, "output_model", "lasso_model") +
    "\n\n"
    "The following command uses the " + PRINT_MODEL("lasso_model") + " to "
    "provide predicted responses for the data " + PRINT_DATASET("test") + " "
    "and save those responses to " + PRINT_DATASET("test_predictions") + ":"
    "\n\n" +
    PRINT_CALL("lars", "input_model", "lasso_model", "test", "test",
        "output_predictions", "test_predictions"));

// See also...
BINDING_SEE_ALSO("@linear_regression", "#linear_regression");
BINDING_SEE_ALSO("Least angle regression (pdf)",
    "http://mlpack.org/papers/lars.pdf");
BINDING_SEE_ALSO("mlpack::regression::LARS C++ class documentation",
    "@doxygen/classmlpack_1_1regression_1_1LARS.html");

PARAM_TMATRIX_IN("input", "Matrix of covariates (X).", "i");
PARAM_MATRIX_IN("responses", "Matrix of responses/observations (y).", "r");

PARAM_MODEL_IN(LARS, "input_model", "Trained LARS model to use.", "m");
PARAM_MODEL_OUT(LARS, "output_model", "Output LARS model.", "M");

PARAM_TMATRIX_IN("test", "Matrix containing points to regress on (test "
    "points).", "t");

PARAM_TMATRIX_OUT("output_predictions", "If --test_file is specified, this "
    "file is where the predicted responses will be saved.", "o");

PARAM_DOUBLE_IN("lambda1", "Regularization parameter for l1-norm penalty.",
    "l", 0);
PARAM_DOUBLE_IN("lambda2", "Regularization parameter for l2-norm penalty.",
    "L", 0);
PARAM_FLAG("use_cholesky", "Use Cholesky decomposition during computation "
    "rather than explicitly computing the full Gram matrix.", "c");

static void mlpackMain()
{
  double lambda1 = IO::GetParam<double>("lambda1");
  double lambda2 = IO::GetParam<double>("lambda2");
  bool useCholesky = IO::HasParam("use_cholesky");

  // Check parameters -- make sure everything given makes sense.
  RequireOnlyOnePassed({ "input", "input_model" }, true);
  if (IO::HasParam("input"))
  {
    RequireOnlyOnePassed({ "responses" }, true, "if input data is specified, "
        "responses must also be specified");
  }
  ReportIgnoredParam({{ "input", false }}, "responses");

  RequireAtLeastOnePassed({ "output_predictions", "output_model" }, false,
      "no results will be saved");
  // If no test points are given, any requested predictions output is ignored.
  ReportIgnoredParam({{ "test", false }}, "output_predictions");

  LARS* lars;
  if (IO::HasParam("input"))
  {
    // Initialize the object.
    lars = new LARS(useCholesky, lambda1, lambda2);
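    // (useCholesky makes LARS use a Cholesky decomposition rather than
    // explicitly computing the full Gram matrix; lambda1 and lambda2 are the
    // L1 and L2 penalty strengths. The object is allocated on the heap
    // because it is later handed to IO as the "output_model" parameter.)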

    // Load covariates. We can avoid LARS transposing our data by choosing to
    // not transpose this data (that's why we used PARAM_TMATRIX_IN).
    mat matX = std::move(IO::GetParam<arma::mat>("input"));

    // Load responses. The responses should be a one-dimensional vector, and it
    // seems more likely that these will be stored with one response per line
    // (one per row). So we should not transpose upon loading.
    mat matY = std::move(IO::GetParam<arma::mat>("responses"));

    // Make sure y is oriented the right way.
    if (matY.n_cols == 1)
      matY = trans(matY);
    if (matY.n_rows > 1)
      Log::Fatal << "Only one column or row allowed in responses file!" << endl;

    if (matY.n_elem != matX.n_rows)
      Log::Fatal << "Number of responses must be equal to number of rows of X!"
          << endl;

    vec beta;
    arma::rowvec y = std::move(matY);
    lars->Train(matX, y, beta, false /* do not transpose */);
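    // After training, beta holds the coefficients of the final model on the
    // regularization path. It is not used further here: the LARS object
    // itself stores the solution path that Predict() uses below.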
  }
  else // We must have --input_model_file.
  {
    lars = IO::GetParam<LARS*>("input_model");
  }

  if (IO::HasParam("test"))
  {
    Log::Info << "Regressing on test points." << endl;

    // Load test points.
    mat testPoints = std::move(IO::GetParam<arma::mat>("test"));

    // Make sure the dimensionality is right. We haven't transposed, so, we
    // check n_cols not n_rows.
    if (testPoints.n_cols != lars->BetaPath().back().n_elem)
      Log::Fatal << "Dimensionality of test set (" << testPoints.n_cols << ") "
          << "is not equal to the dimensionality of the model ("
          << lars->BetaPath().back().n_elem << ")!" << endl;

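    // Predict() expects points in columns by default, but the test matrix was
    // loaded untransposed (one point per row), so pass its transpose along
    // with 'false' to indicate the points are not row-major.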
    arma::rowvec predictions;
    lars->Predict(testPoints.t(), predictions, false);

    // Save test predictions (one per line).
    IO::GetParam<arma::mat>("output_predictions") = predictions.t();
  }

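  // Hand the model to IO, whether it was trained above or passed in via
  // "input_model"; the binding infrastructure is responsible for saving it
  // (if "output_model" was specified) and for cleaning up its memory.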
  IO::GetParam<LARS*>("output_model") = lars;
}