1 /**
2  * @file spsa_test.cpp
3  * @author N Rajiv Vaidyanathan
4  * @author Marcus Edel
5  *
6  * Test file for the SPSA optimizer.
7  *
8  * ensmallen is free software; you may redistribute it and/or modify it under
9  * the terms of the 3-clause BSD license.  You should have received a copy of
10  * the 3-clause BSD license along with ensmallen.  If not, see
11  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12  */
13 
14 #include <ensmallen.hpp>
15 #include "catch.hpp"
16 #include "test_function_tools.hpp"
17 
18 using namespace arma;
19 using namespace ens;
20 using namespace ens::test;
21 
22 /**
23  * Test the SPSA optimizer on the Sphere function.
24  */
25 TEST_CASE("SPSASphereFunctionTest", "[SPSATest]")
26 {
27   SPSA optimizer(0.1, 0.102, 0.16, 0.3, 100000, 0);
28   FunctionTest<SphereFunction>(optimizer, 1.0, 0.1);
29 }
30 
31 /**
32  * Test the SPSA optimizer on the Sphere function using arma::fmat.
33  */
34 TEST_CASE("SPSASphereFunctionFMatTest", "[SPSATest]")
35 {
36   SPSA optimizer(0.1, 0.102, 0.16, 0.3, 100000, 0);
37   FunctionTest<SphereFunction, arma::fmat>(optimizer, 1.0, 0.1);
38 }
39 
40 /**
41  * Test the SPSA optimizer on the Sphere function using arma::sp_mat.
42  */
43 TEST_CASE("SPSASphereFunctionSpMatTest", "[SPSATest]")
44 {
45   SPSA optimizer(0.1, 0.102, 0.16, 0.3, 100000, 0);
46   FunctionTest<SphereFunction, arma::sp_mat>(optimizer, 1.0, 0.1);
47 }
48 
49 /**
50  * Test the SPSA optimizer on the Matyas function.
51  */
52 TEST_CASE("SPSAMatyasFunctionTest", "[SPSATest]")
53 {
54   SPSA optimizer(0.1, 0.102, 0.16, 0.3, 100000, 0);
55   FunctionTest<MatyasFunction>(optimizer, 0.1, 0.01);
56 }
57 
58 /**
59  * Run SPSA on logistic regression and make sure the results are acceptable.
60  */
61 TEST_CASE("SPSALogisticRegressionTest", "[SPSATest]")
62 {
63   // We allow 10 trials, because SPSA is definitely not guaranteed to
64   // converge.
65   SPSA optimizer(0.5, 0.102, 0.002, 0.3, 5000, 1e-8);
66   LogisticRegressionFunctionTest(optimizer, 0.003, 0.006, 10);
67 }
68