1 // Ceres Solver - A fast non-linear least squares minimizer
2 // Copyright 2015 Google Inc. All rights reserved.
3 // http://ceres-solver.org/
4 //
5 // Redistribution and use in source and binary forms, with or without
6 // modification, are permitted provided that the following conditions are met:
7 //
8 // * Redistributions of source code must retain the above copyright notice,
9 //   this list of conditions and the following disclaimer.
10 // * Redistributions in binary form must reproduce the above copyright notice,
11 //   this list of conditions and the following disclaimer in the documentation
12 //   and/or other materials provided with the distribution.
13 // * Neither the name of Google Inc. nor the names of its contributors may be
14 //   used to endorse or promote products derived from this software without
15 //   specific prior written permission.
16 //
17 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 // POSSIBILITY OF SUCH DAMAGE.
28 //
29 // Author: sameeragarwal@google.com (Sameer Agarwal)
30 
#include "ceres/levenberg_marquardt_strategy.h"

#include <cmath>
#include <memory>

#include "ceres/internal/eigen.h"
#include "ceres/linear_solver.h"
#include "ceres/trust_region_strategy.h"
#include "glog/logging.h"
#include "gmock/gmock.h"
#include "gmock/mock-log.h"
#include "gtest/gtest.h"
42 
43 using testing::_;
44 using testing::AllOf;
45 using testing::AnyNumber;
46 using testing::HasSubstr;
47 using testing::ScopedMockLog;
48 
49 namespace ceres {
50 namespace internal {
51 
// Absolute tolerance handed to EXPECT_NEAR when the regularization diagonal
// received by the linear solver is compared against its expected values.
constexpr double kTolerance = 1e-16;
53 
54 // Linear solver that takes as input a vector and checks that the
55 // caller passes the same vector as LinearSolver::PerSolveOptions.D.
56 class RegularizationCheckingLinearSolver : public DenseSparseMatrixSolver {
57  public:
RegularizationCheckingLinearSolver(const int num_cols,const double * diagonal)58   RegularizationCheckingLinearSolver(const int num_cols, const double* diagonal)
59       : num_cols_(num_cols), diagonal_(diagonal) {}
60 
~RegularizationCheckingLinearSolver()61   virtual ~RegularizationCheckingLinearSolver() {}
62 
63  private:
SolveImpl(DenseSparseMatrix * A,const double * b,const LinearSolver::PerSolveOptions & per_solve_options,double * x)64   LinearSolver::Summary SolveImpl(
65       DenseSparseMatrix* A,
66       const double* b,
67       const LinearSolver::PerSolveOptions& per_solve_options,
68       double* x) final {
69     CHECK(per_solve_options.D != nullptr);
70     for (int i = 0; i < num_cols_; ++i) {
71       EXPECT_NEAR(per_solve_options.D[i], diagonal_[i], kTolerance)
72           << i << " " << per_solve_options.D[i] << " " << diagonal_[i];
73     }
74     return LinearSolver::Summary();
75   }
76 
77   const int num_cols_;
78   const double* diagonal_;
79 };
80 
// Verifies how the value reported by Radius() evolves over a sequence of
// StepRejected() / StepAccepted() calls, ending with the clamp at
// options.max_radius.
TEST(LevenbergMarquardtStrategy, AcceptRejectStepRadiusScaling) {
  TrustRegionStrategy::Options options;
  options.initial_radius = 2.0;
  options.max_radius = 20.0;
  options.min_lm_diagonal = 1e-8;
  options.max_lm_diagonal = 1e8;

  // We need a non-null pointer here, so anything should do. The solver lives
  // on the stack (as in the sibling test) and is declared before the strategy
  // so that it outlives it.
  RegularizationCheckingLinearSolver linear_solver(0, nullptr);
  options.linear_solver = &linear_solver;

  LevenbergMarquardtStrategy lms(options);
  EXPECT_EQ(lms.Radius(), options.initial_radius);

  // Back-to-back rejections are expected to shrink the radius by a growing
  // factor: 2.0 -> 1.0 (halved), then 1.0 -> 0.25 (quartered).
  lms.StepRejected(0.0);
  EXPECT_EQ(lms.Radius(), 1.0);
  lms.StepRejected(-1.0);
  EXPECT_EQ(lms.Radius(), 0.25);

  // Each StepAccepted(1.0) is expected to triple the radius.
  lms.StepAccepted(1.0);
  EXPECT_EQ(lms.Radius(), 0.25 * 3.0);
  lms.StepAccepted(1.0);
  EXPECT_EQ(lms.Radius(), 0.25 * 3.0 * 3.0);

  // StepAccepted(0.25) is expected to divide the radius by 1.125.
  lms.StepAccepted(0.25);
  EXPECT_EQ(lms.Radius(), 0.25 * 3.0 * 3.0 / 1.125);

  lms.StepAccepted(1.0);
  EXPECT_EQ(lms.Radius(), 0.25 * 3.0 * 3.0 / 1.125 * 3.0);
  lms.StepAccepted(1.0);
  EXPECT_EQ(lms.Radius(), 0.25 * 3.0 * 3.0 / 1.125 * 3.0 * 3.0);

  // One more tripling would give 54, which exceeds max_radius (20), so the
  // radius is expected to be clamped.
  lms.StepAccepted(1.0);
  EXPECT_EQ(lms.Radius(), options.max_radius);
}
112 
// Verifies that ComputeStep() passes the expected regularization diagonal to
// the linear solver through PerSolveOptions.D, and that a linear solve which
// returns a default-constructed summary is logged as a warning and reported
// as LINEAR_SOLVER_FAILURE.
TEST(LevenbergMarquardtStrategy, CorrectDiagonalToLinearSolver) {
  // 2 x 3 Jacobian whose columns have squared norms 0, 2 and 1e4.
  Matrix jacobian(2, 3);
  jacobian.setZero();
  jacobian(0, 0) = 0.0;
  jacobian(0, 1) = 1.0;
  jacobian(1, 1) = 1.0;
  jacobian(0, 2) = 100.0;

  // One entry per Jacobian row. The checking solver never reads these.
  // NOTE(review): sized to the row count on the assumption that ComputeStep's
  // third argument is the residual vector - confirm against the header.
  double residuals[2] = {1.0, 1.0};
  // Step output buffer, one entry per Jacobian column; zero-initialized so
  // that no indeterminate values are ever observed.
  double x[3] = {0.0, 0.0, 0.0};
  DenseSparseMatrix dsm(jacobian);

  TrustRegionStrategy::Options options;
  options.initial_radius = 2.0;
  options.max_radius = 20.0;
  options.min_lm_diagonal = 1e-2;
  options.max_lm_diagonal = 1e2;

  // Expected contents of PerSolveOptions.D: the squared column norms limited
  // to [min_lm_diagonal, max_lm_diagonal] (0 -> min, 2 stays, 1e4 -> max),
  // divided by the initial radius, then square-rooted.
  double diagonal[3];
  diagonal[0] = options.min_lm_diagonal;
  diagonal[1] = 2.0;
  diagonal[2] = options.max_lm_diagonal;
  for (double& entry : diagonal) {
    entry = std::sqrt(entry / options.initial_radius);
  }

  RegularizationCheckingLinearSolver linear_solver(3, diagonal);
  options.linear_solver = &linear_solver;

  LevenbergMarquardtStrategy lms(options);
  TrustRegionStrategy::PerSolveOptions pso;

  {
    ScopedMockLog log;
    // Catch-all so unrelated log lines do not fail the test. gMock matches
    // newer expectations first, so the specific one below takes precedence.
    EXPECT_CALL(log, Log(_, _, _)).Times(AnyNumber());
    // This using directive is needed to get around the fact that there
    // are versions of glog which are not in the google namespace.
    using namespace google;

#if defined(_MSC_VER)
    // Use GLOG_WARNING to support MSVC if GLOG_NO_ABBREVIATED_SEVERITIES
    // is defined.
    EXPECT_CALL(log,
                Log(GLOG_WARNING, _, HasSubstr("Failed to compute a step")));
#else
    EXPECT_CALL(log,
                Log(google::WARNING, _, HasSubstr("Failed to compute a step")));
#endif

    TrustRegionStrategy::Summary summary =
        lms.ComputeStep(pso, &dsm, residuals, x);
    EXPECT_EQ(summary.termination_type, LINEAR_SOLVER_FAILURE);
  }
}
167 
168 }  // namespace internal
169 }  // namespace ceres
170