1 /**
2 * @file tests/sparse_coding_test.cpp
3 *
4 * Test for Sparse Coding
5 *
6 * mlpack is free software; you may redistribute it and/or modify it under the
7 * terms of the 3-clause BSD license. You should have received a copy of the
8 * 3-clause BSD license along with mlpack. If not, see
9 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
10 */
11
// Note: The comparisons against zero below use Catch's Approx(0.0).margin(tol),
// i.e. an absolute tolerance.  (An earlier form of this note referred to
// BOOST_REQUIRE_CLOSE and FPC_WEAK, which this file does not use.)
14
15 #include <mlpack/core.hpp>
16 #include <mlpack/methods/sparse_coding/sparse_coding.hpp>
17
18 #include "catch.hpp"
19 #include "test_catch_tools.hpp"
20 #include "serialization_catch.hpp"
21
22 using namespace arma;
23 using namespace mlpack;
24 using namespace mlpack::regression;
25 using namespace mlpack::sparse_coding;
26
SCVerifyCorrectness(vec beta,vec errCorr,double lambda)27 void SCVerifyCorrectness(vec beta, vec errCorr, double lambda)
28 {
29 const double tol = 1e-12;
30 size_t nDims = beta.n_elem;
31 for (size_t j = 0; j < nDims; ++j)
32 {
33 if (beta(j) == 0)
34 {
35 // Make sure that errCorr(j) <= lambda.
36 REQUIRE(std::max(fabs(errCorr(j)) - lambda, 0.0) ==
37 Approx(0.0).margin(tol));
38 }
39 else if (beta(j) < 0)
40 {
41 // Make sure that errCorr(j) == lambda.
42 REQUIRE(errCorr(j) - lambda == Approx(0.0).margin(tol));
43 }
44 else // beta(j) > 0.
45 {
46 // Make sure that errCorr(j) == -lambda.
47 REQUIRE(errCorr(j) + lambda == Approx(0.0).margin(tol));
48 }
49 }
50 }
51
52 TEST_CASE("SparseCodingTestCodingStepLasso", "[SparseCodingTest]")
53 {
54 double lambda1 = 0.1;
55 uword nAtoms = 25;
56
57 mat X;
58 X.load("mnist_first250_training_4s_and_9s.arm");
59 uword nPoints = X.n_cols;
60
61 // Normalize each point since these are images.
62 for (uword i = 0; i < nPoints; ++i)
63 {
64 X.col(i) /= norm(X.col(i), 2);
65 }
66
67 SparseCoding sc(nAtoms, lambda1);
68 mat Z;
69 DataDependentRandomInitializer::Initialize(X, 25, sc.Dictionary());
70 sc.Encode(X, Z);
71
72 mat D = sc.Dictionary();
73
74 for (uword i = 0; i < nPoints; ++i)
75 {
76 vec errCorr = trans(D) * (D * Z.unsafe_col(i) - X.unsafe_col(i));
77 SCVerifyCorrectness(Z.unsafe_col(i), errCorr, lambda1);
78 }
79 }
80
81 TEST_CASE("SparseCodingTestCodingStepElasticNet", "[SparseCodingTest]")
82 {
83 double lambda1 = 0.1;
84 double lambda2 = 0.2;
85 uword nAtoms = 25;
86
87 mat X;
88 X.load("mnist_first250_training_4s_and_9s.arm");
89 uword nPoints = X.n_cols;
90
91 // Normalize each point since these are images.
92 for (uword i = 0; i < nPoints; ++i)
93 X.col(i) /= norm(X.col(i), 2);
94
95 SparseCoding sc(nAtoms, lambda1, lambda2);
96 mat Z;
97 DataDependentRandomInitializer::Initialize(X, 25, sc.Dictionary());
98 sc.Encode(X, Z);
99
100 mat D = sc.Dictionary();
101
102 for (uword i = 0; i < nPoints; ++i)
103 {
104 vec errCorr =
105 (trans(D) * D + lambda2 * eye(nAtoms, nAtoms)) * Z.unsafe_col(i)
106 - trans(D) * X.unsafe_col(i);
107
108 SCVerifyCorrectness(Z.unsafe_col(i), errCorr, lambda1);
109 }
110 }
111
112 TEST_CASE("SparseCodingTestDictionaryStep", "[SparseCodingTest]")
113 {
114 const double tol = 1e-6;
115
116 double lambda1 = 0.1;
117 uword nAtoms = 25;
118
119 mat X;
120 X.load("mnist_first250_training_4s_and_9s.arm");
121 uword nPoints = X.n_cols;
122
123 // Normalize each point since these are images.
124 for (uword i = 0; i < nPoints; ++i)
125 X.col(i) /= norm(X.col(i), 2);
126
127 SparseCoding sc(nAtoms, lambda1, 0.0, 0, 0.01, tol);
128 mat Z;
129 DataDependentRandomInitializer::Initialize(X, 25, sc.Dictionary());
130 sc.Encode(X, Z);
131
132 mat D = sc.Dictionary();
133
134 uvec adjacencies = find(Z);
135 double normGradient = sc.OptimizeDictionary(X, Z, adjacencies);
136
137 REQUIRE(normGradient == Approx(0.0).margin(tol));
138 }
139
140 TEST_CASE("SerializationTest", "[SparseCodingTest]")
141 {
142 mat X = randu<mat>(100, 100);
143 size_t nAtoms = 25;
144
145 SparseCoding sc(nAtoms, 0.05, 0.1);
146 sc.Train(X);
147
148 mat Y = randu<mat>(100, 200);
149 mat codes;
150 sc.Encode(Y, codes);
151
152 SparseCoding scXml(50, 0.01), scText(nAtoms, 0.05), scBinary(0, 0.0);
153 SerializeObjectAll(sc, scXml, scText, scBinary);
154
155 CheckMatrices(sc.Dictionary(), scXml.Dictionary(), scText.Dictionary(),
156 scBinary.Dictionary());
157
158 mat xmlCodes, textCodes, binaryCodes;
159 scXml.Encode(Y, xmlCodes);
160 scText.Encode(Y, textCodes);
161 scBinary.Encode(Y, binaryCodes);
162
163 CheckMatrices(codes, xmlCodes, textCodes, binaryCodes);
164
165 // Check the parameters, too.
166 REQUIRE(sc.Atoms() == scXml.Atoms());
167 REQUIRE(sc.Atoms() == scText.Atoms());
168 REQUIRE(sc.Atoms() == scBinary.Atoms());
169
170 REQUIRE(sc.Lambda1() == Approx(scXml.Lambda1()).epsilon(1e-7));
171 REQUIRE(sc.Lambda1() == Approx(scText.Lambda1()).epsilon(1e-7));
172 REQUIRE(sc.Lambda1() == Approx(scBinary.Lambda1()).epsilon(1e-7));
173
174 REQUIRE(sc.Lambda2() == Approx(scXml.Lambda2()).epsilon(1e-7));
175 REQUIRE(sc.Lambda2() == Approx(scText.Lambda2()).epsilon(1e-7));
176 REQUIRE(sc.Lambda2() == Approx(scBinary.Lambda2()).epsilon(1e-7));
177
178 REQUIRE(sc.MaxIterations() == scXml.MaxIterations());
179 REQUIRE(sc.MaxIterations() == scText.MaxIterations());
180 REQUIRE(sc.MaxIterations() == scBinary.MaxIterations());
181
182 REQUIRE(sc.ObjTolerance() == Approx(scXml.ObjTolerance()).epsilon(1e-7));
183 REQUIRE(sc.ObjTolerance() == Approx(scText.ObjTolerance()).epsilon(1e-7));
184 REQUIRE(sc.ObjTolerance() == Approx(scBinary.ObjTolerance()).epsilon(1e-7));
185
186 REQUIRE(sc.NewtonTolerance() ==
187 Approx(scXml.NewtonTolerance()).epsilon(1e-7));
188 REQUIRE(sc.NewtonTolerance() ==
189 Approx(scText.NewtonTolerance()).epsilon(1e-7));
190 REQUIRE(sc.NewtonTolerance() ==
191 Approx(scBinary.NewtonTolerance()).epsilon(1e-7));
192 }
193
194 /**
195 * Test that SparseCoding::Train() returns finite final objective value.
196 */
197 TEST_CASE("SparseCodingTrainReturnObjective", "[SparseCodingTest]")
198 {
199 const double tol = 1e-6;
200
201 double lambda1 = 0.1;
202 uword nAtoms = 25;
203
204 mat X;
205 X.load("mnist_first250_training_4s_and_9s.arm");
206 uword nPoints = X.n_cols;
207
208 // Normalize each point since these are images.
209 for (uword i = 0; i < nPoints; ++i)
210 X.col(i) /= norm(X.col(i), 2);
211
212 SparseCoding sc(nAtoms, lambda1, 0.0, 0, 0.01, tol);
213 double objVal = sc.Train(X);
214
215 REQUIRE(std::isfinite(objVal) == true);
216 }
217