1 /**
2 * @file methods/sparse_coding/sparse_coding.cpp
3 * @author Nishant Mehta
4 *
5 * Implementation of Sparse Coding with Dictionary Learning using l1 (LASSO) or
6 * l1+l2 (Elastic Net) regularization.
7 *
8 * mlpack is free software; you may redistribute it and/or modify it under the
9 * terms of the 3-clause BSD license. You should have received a copy of the
10 * 3-clause BSD license along with mlpack. If not, see
11 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12 */
13 #include "sparse_coding.hpp"
14 #include <mlpack/core/math/lin_alg.hpp>
15
16 namespace mlpack {
17 namespace sparse_coding {
18
SparseCoding(const size_t atoms,const double lambda1,const double lambda2,const size_t maxIterations,const double objTolerance,const double newtonTolerance)19 SparseCoding::SparseCoding(
20 const size_t atoms,
21 const double lambda1,
22 const double lambda2,
23 const size_t maxIterations,
24 const double objTolerance,
25 const double newtonTolerance) :
26 atoms(atoms),
27 lambda1(lambda1),
28 lambda2(lambda2),
29 maxIterations(maxIterations),
30 objTolerance(objTolerance),
31 newtonTolerance(newtonTolerance)
32 {
33 // Nothing to do.
34 }
35
Encode(const arma::mat & data,arma::mat & codes)36 void SparseCoding::Encode(const arma::mat& data, arma::mat& codes)
37 {
38 // When using the Cholesky version of LARS, this is correct even if
39 // lambda2 > 0.
40 arma::mat matGram = trans(dictionary) * dictionary;
41
42 codes.set_size(atoms, data.n_cols);
43 for (size_t i = 0; i < data.n_cols; ++i)
44 {
45 // Report progress.
46 if ((i % 100) == 0)
47 Log::Debug << "Optimization at point " << i << "." << std::endl;
48
49 bool useCholesky = true;
50 regression::LARS lars(useCholesky, matGram, lambda1, lambda2);
51
52 // Create an alias of the code (using the same memory), and then LARS will
53 // place the result directly into that; then we will not need to have an
54 // extra copy.
55 arma::vec code = codes.unsafe_col(i);
56 arma::rowvec responses = data.unsafe_col(i).t();
57 lars.Train(dictionary, responses, code, false);
58 }
59 }
60
61 // Dictionary step for optimization.
OptimizeDictionary(const arma::mat & data,const arma::mat & codes,const arma::uvec & adjacencies)62 double SparseCoding::OptimizeDictionary(const arma::mat& data,
63 const arma::mat& codes,
64 const arma::uvec& adjacencies)
65 {
66 // Count the number of atomic neighbors for each point x^i.
67 arma::uvec neighborCounts = arma::zeros<arma::uvec>(data.n_cols, 1);
68
69 if (adjacencies.n_elem > 0)
70 {
71 // This gets the column index. Intentional integer division.
72 size_t curPointInd = (size_t) (adjacencies(0) / atoms);
73
74 size_t nextColIndex = (curPointInd + 1) * atoms;
75 for (size_t l = 1; l < adjacencies.n_elem; ++l)
76 {
77 // If l no longer refers to an element in this column, advance the column
78 // number accordingly.
79 if (adjacencies(l) >= nextColIndex)
80 {
81 curPointInd = (size_t) (adjacencies(l) / atoms);
82 nextColIndex = (curPointInd + 1) * atoms;
83 }
84
85 ++neighborCounts(curPointInd);
86 }
87 }
88
89 // Handle the case of inactive atoms (atoms not used in the given coding).
90 std::vector<size_t> inactiveAtoms;
91
92 for (size_t j = 0; j < atoms; ++j)
93 {
94 if (arma::accu(codes.row(j) != 0) == 0)
95 inactiveAtoms.push_back(j);
96 }
97
98 const size_t nInactiveAtoms = inactiveAtoms.size();
99 const size_t nActiveAtoms = atoms - nInactiveAtoms;
100
101 // Efficient construction of Z restricted to active atoms.
102 arma::mat matActiveZ;
103 if (nInactiveAtoms > 0)
104 {
105 math::RemoveRows(codes, inactiveAtoms, matActiveZ);
106 }
107
108 if (nInactiveAtoms > 0)
109 {
110 Log::Warn << "There are " << nInactiveAtoms
111 << " inactive atoms. They will be re-initialized randomly.\n";
112 }
113
114 Log::Debug << "Solving Dual via Newton's Method.\n";
115
116 // Solve using Newton's method in the dual - note that the final dot
117 // multiplication with inv(A) seems to be unavoidable. Although more
118 // expensive, the code written this way (we use solve()) should be more
119 // numerically stable than just using inv(A) for everything.
120 arma::vec dualVars = arma::zeros<arma::vec>(nActiveAtoms);
121
122 // vec dualVars = 1e-14 * ones<vec>(nActiveAtoms);
123
124 // Method used by feature sign code - fails miserably here. Perhaps the
125 // MATLAB optimizer fmincon does something clever?
126 // vec dualVars = 10.0 * randu(nActiveAtoms, 1);
127
128 // vec dualVars = diagvec(solve(dictionary, data * trans(codes))
129 // - codes * trans(codes));
130 // for (size_t i = 0; i < dualVars.n_elem; ++i)
131 // if (dualVars(i) < 0)
132 // dualVars(i) = 0;
133
134 bool converged = false;
135
136 // If we have any inactive atoms, we must construct these differently.
137 arma::mat codesXT;
138 arma::mat codesZT;
139
140 if (inactiveAtoms.empty())
141 {
142 codesXT = codes * trans(data);
143 codesZT = codes * trans(codes);
144 }
145 else
146 {
147 codesXT = matActiveZ * trans(data);
148 codesZT = matActiveZ * trans(matActiveZ);
149 }
150
151 double normGradient = 0;
152 double improvement = 0;
153 for (size_t t = 1; (t != maxIterations) && !converged; ++t)
154 {
155 arma::mat A = codesZT + diagmat(dualVars);
156
157 arma::mat matAInvZXT = solve(A, codesXT);
158
159 arma::vec gradient = -arma::sum(arma::square(matAInvZXT), 1);
160 gradient += 1;
161
162 arma::mat hessian = -(-2 * (matAInvZXT * trans(matAInvZXT)) % inv(A));
163
164 arma::vec searchDirection = -solve(hessian, gradient);
165
166 // Armijo line search.
167 const double c = 1e-4;
168 double alpha = 1.0;
169 const double rho = 0.9;
170 double sufficientDecrease = c * dot(gradient, searchDirection);
171
172 // A maxIterations parameter for the Armijo line search may be a good idea,
173 // but it doesn't seem to be causing any problems for now.
174 while (true)
175 {
176 // Calculate objective.
177 double sumDualVars = arma::sum(dualVars);
178 double fOld = -(-trace(trans(codesXT) * matAInvZXT) - sumDualVars);
179 double fNew = -(-trace(trans(codesXT) * solve(codesZT +
180 diagmat(dualVars + alpha * searchDirection), codesXT)) -
181 (sumDualVars + alpha * arma::sum(searchDirection)));
182
183 if (fNew <= fOld + alpha * sufficientDecrease)
184 {
185 searchDirection = alpha * searchDirection;
186 improvement = fOld - fNew;
187 break;
188 }
189
190 alpha *= rho;
191 }
192
193 // Take step and print useful information.
194 dualVars += searchDirection;
195 normGradient = arma::norm(gradient, 2);
196 Log::Debug << "Newton Method iteration " << t << ":" << std::endl;
197 Log::Debug << " Gradient norm: " << std::scientific << normGradient
198 << "." << std::endl;
199 Log::Debug << " Improvement: " << std::scientific << improvement << ".\n";
200
201 if (normGradient < newtonTolerance)
202 converged = true;
203 }
204
205 if (inactiveAtoms.empty())
206 {
207 // Directly update dictionary.
208 dictionary = trans(solve(codesZT + diagmat(dualVars), codesXT));
209 }
210 else
211 {
212 arma::mat activeDictionary = trans(solve(codesZT +
213 diagmat(dualVars), codesXT));
214
215 // Update all atoms.
216 size_t currentInactiveIndex = 0;
217 for (size_t i = 0; i < atoms; ++i)
218 {
219 if (inactiveAtoms[currentInactiveIndex] == i)
220 {
221 // This atom is inactive. Reinitialize it randomly.
222 dictionary.col(i) = (data.col(math::RandInt(data.n_cols)) +
223 data.col(math::RandInt(data.n_cols)) +
224 data.col(math::RandInt(data.n_cols)));
225
226 dictionary.col(i) /= arma::norm(dictionary.col(i), 2);
227
228 // Increment inactive index counter.
229 ++currentInactiveIndex;
230 }
231 else
232 {
233 // Update estimate.
234 dictionary.col(i) = activeDictionary.col(i - currentInactiveIndex);
235 }
236 }
237 }
238
239 return normGradient;
240 }
241
242 // Project each atom of the dictionary back into the unit ball (if necessary).
ProjectDictionary()243 void SparseCoding::ProjectDictionary()
244 {
245 for (size_t j = 0; j < atoms; ++j)
246 {
247 double atomNorm = arma::norm(dictionary.col(j), 2);
248 if (atomNorm > 1)
249 {
250 Log::Info << "Norm of atom " << j << " exceeds 1 (" << std::scientific
251 << atomNorm << "). Shrinking...\n";
252 dictionary.col(j) /= atomNorm;
253 }
254 }
255 }
256
257 // Compute the objective function.
Objective(const arma::mat & data,const arma::mat & codes) const258 double SparseCoding::Objective(const arma::mat& data, const arma::mat& codes)
259 const
260 {
261 double l11NormZ = arma::sum(arma::sum(arma::abs(codes)));
262 double froNormResidual = arma::norm(data - (dictionary * codes), "fro");
263
264 if (lambda2 > 0)
265 {
266 double froNormZ = arma::norm(codes, "fro");
267 return 0.5 * (std::pow(froNormResidual, 2.0) + (lambda2 *
268 std::pow(froNormZ, 2.0))) + (lambda1 * l11NormZ);
269 }
270 else // It can be simpler.
271 {
272 return 0.5 * std::pow(froNormResidual, 2.0) + lambda1 * l11NormZ;
273 }
274 }
275
276 } // namespace sparse_coding
277 } // namespace mlpack
278