1 // This is core/vnl/algo/vnl_lbfgs.cxx
2 //:
3 // \file
4 //
5 // \author Andrew W. Fitzgibbon, Oxford RRG
6 // \date 22 Aug 99
7 //
8 //-----------------------------------------------------------------------------
9
10 #include <cmath>
11 #include <iostream>
12 #include <iomanip>
13 #include "vnl_lbfgs.h"
14
15 #include <vnl/algo/vnl_netlib.h> // lbfgs_()
16
17 //: Default constructor.
18 // memory is set to 5, line_search_accuracy to 0.9.
19 // Calls init_parameters
vnl_lbfgs()20 vnl_lbfgs::vnl_lbfgs()
21
22 {
23 init_parameters();
24 }
25
26 //: Constructor. f is the cost function to be minimized.
27 // Calls init_parameters
vnl_lbfgs(vnl_cost_function & f)28 vnl_lbfgs::vnl_lbfgs(vnl_cost_function & f)
29 : f_(&f)
30 {
31 init_parameters();
32 }
33
34 //: Called by constructors.
35 // Memory is set to 5, line_search_accuracy to 0.9, default_step_length to 1.
36 void
init_parameters()37 vnl_lbfgs::init_parameters()
38 {
39 memory = 5;
40 line_search_accuracy = 0.9;
41 default_step_length = 1.0;
42 }
43
44 bool
minimize(vnl_vector<double> & x)45 vnl_lbfgs::minimize(vnl_vector<double> & x)
46 {
47 // Local variables
48 // The driver for vnl_lbfgs must always declare LB2 as EXTERNAL
49
50 long n = f_->get_number_of_unknowns();
51 long m = memory; // The number of basis vectors to remember.
52
53 // Create an instance of the lbfgs global data to pass as an
54 // argument. It must persist through all calls in this
55 // minimization.
56 v3p_netlib_lbfgs_global_t lbfgs_global;
57 v3p_netlib_lbfgs_init(&lbfgs_global);
58
59 long iprint[2] = { 1, 0 };
60 vnl_vector<double> g(n);
61
62 // Workspace
63 vnl_vector<double> diag(n);
64
65 vnl_vector<double> w(n * (2 * m + 1) + 2 * m);
66
67 if (verbose_)
68 std::cerr << "vnl_lbfgs: n = " << n << ", memory = " << m << ", Workspace = " << w.size() << "[ "
69 << (w.size() / 128.0 / 1024.0) << " MB], ErrorScale = " << f_->reported_error(1)
70 << ", xnorm = " << x.magnitude() << std::endl;
71
72 bool we_trace = (verbose_ && !trace);
73
74 if (we_trace)
75 std::cerr << "vnl_lbfgs: ";
76
77 double best_f = 0;
78 vnl_vector<double> best_x;
79
80 bool ok;
81 this->num_evaluations_ = 0;
82 this->num_iterations_ = 0;
83 long iflag = 0;
84 while (true)
85 {
86 // We do not wish to provide the diagonal matrices Hk0, and therefore set DIAGCO to FALSE.
87 v3p_netlib_logical diagco = false;
88
89 // Set these every iter in case user changes them to bail out
90 double eps = gtol; // Gradient tolerance
91 double local_xtol = 1e-16;
92 lbfgs_global.gtol = line_search_accuracy; // set to 0.1 for huge problems or cheap functions
93 lbfgs_global.stpinit = default_step_length;
94
95 // Call function
96 double f;
97 f_->compute(x, &f, &g);
98 if (this->num_evaluations_ == 0)
99 {
100 this->start_error_ = f;
101 best_x = x;
102 best_f = f;
103 }
104 else if (f < best_f)
105 {
106 best_x = x;
107 best_f = f;
108 }
109
110 #define print_(i, a, b, c, d) \
111 std::cerr << std::setw(6) << (i) << ' ' << std::setw(20) << (a) << ' ' << std::setw(20) << (b) << ' ' \
112 << std::setw(20) << (c) << ' ' << std::setw(20) << (d) << '\n'
113
114 if (check_derivatives_)
115 {
116 std::cerr << "vnl_lbfgs: f = " << f_->reported_error(f) << ", computing FD gradient\n";
117 vnl_vector<double> fdg = f_->fdgradf(x);
118 if (verbose_)
119 {
120 int l = n;
121 int limit = 100;
122 int limit_tail = 10;
123 if (l > limit + limit_tail)
124 {
125 std::cerr << " [ Showing only first " << limit << " components ]\n";
126 l = limit;
127 }
128 print_("i", "x", "g", "fdg", "dg");
129 print_("-", "-", "-", "---", "--");
130 for (int i = 0; i < l; ++i)
131 print_(i, x[i], g[i], fdg[i], g[i] - fdg[i]);
132 if (n > limit)
133 {
134 std::cerr << " ...\n";
135 for (int i = n - limit_tail; i < n; ++i)
136 print_(i, x[i], g[i], fdg[i], g[i] - fdg[i]);
137 }
138 }
139 std::cerr << " ERROR = " << (fdg - g).squared_magnitude() / std::sqrt(double(n)) << "\n";
140 }
141
142 iprint[0] = trace ? 1 : -1; // -1 no o/p, 0 start and end, 1 every iter.
143 iprint[1] = 0; // 1 prints X and G
144 v3p_netlib_lbfgs_(&n,
145 &m,
146 x.data_block(),
147 &f,
148 g.data_block(),
149 &diagco,
150 diag.data_block(),
151 iprint,
152 &eps,
153 &local_xtol,
154 w.data_block(),
155 &iflag,
156 &lbfgs_global);
157
158 this->report_eval(f);
159
160 if (this->report_iter())
161 {
162 failure_code_ = FAILED_USER_REQUEST;
163 ok = false;
164 x = best_x;
165 break;
166 }
167
168 if (we_trace)
169 std::cerr << iflag << ":" << f_->reported_error(f) << " ";
170
171 if (iflag == 0)
172 {
173 // Successful return
174 this->end_error_ = f;
175 ok = true;
176 x = best_x;
177 break;
178 }
179
180 if (iflag < 0)
181 {
182 // Netlib routine lbfgs failed
183 std::cerr << "vnl_lbfgs: Error. Netlib routine lbfgs failed.\n";
184 ok = false;
185 x = best_x;
186 break;
187 }
188
189 if (this->num_evaluations_ > get_max_function_evals())
190 {
191 failure_code_ = TOO_MANY_ITERATIONS;
192 ok = false;
193 x = best_x;
194 break;
195 }
196 }
197 if (we_trace)
198 std::cerr << "done\n";
199
200 return ok;
201 }
202