1 // Copyright (C) 2010  Davis E. King (davis@dlib.net)
2 // License: Boost Software License   See LICENSE.txt for the full license.
3 
4 
5 #include <dlib/optimization.h>
6 #include "optimization_test_functions.h"
7 #include <sstream>
8 #include <string>
9 #include <cstdlib>
10 #include <ctime>
11 #include <vector>
12 #include "../rand.h"
13 
14 #include "tester.h"
15 
16 
17 namespace
18 {
19 
20     using namespace test;
21     using namespace dlib;
22     using namespace std;
23     using namespace dlib::test_functions;
24 
    // Logger shared by every test in this file; it is constructed with the
    // name "test.least_squares", which identifies its output.
    logger dlog("test.least_squares");
26 
27 // ----------------------------------------------------------------------------------------
28 
test_with_chebyquad()29     void test_with_chebyquad()
30     {
31         print_spinner();
32         {
33             matrix<double,0,1> ch;
34 
35             ch = chebyquad_start(2);
36 
37             solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
38                                 chebyquad_residual,
39                                 derivative(chebyquad_residual),
40                                 range(0,ch.size()-1),
41                                 ch);
42 
43             dlog << LINFO << "chebyquad 2 obj: " << chebyquad(ch);
44             dlog << LINFO << "chebyquad 2 der: " << length(chebyquad_derivative(ch));
45             dlog << LINFO << "chebyquad 2 error: " << length(ch - chebyquad_solution(2));
46 
47             DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5);
48 
49         }
50         {
51             matrix<double,0,1> ch;
52 
53             ch = chebyquad_start(2);
54 
55             solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
56                                 chebyquad_residual,
57                                 derivative(chebyquad_residual),
58                                 range(0,ch.size()-1),
59                                 ch);
60 
61             dlog << LINFO << "LM chebyquad 2 obj: " << chebyquad(ch);
62             dlog << LINFO << "LM chebyquad 2 der: " << length(chebyquad_derivative(ch));
63             dlog << LINFO << "LM chebyquad 2 error: " << length(ch - chebyquad_solution(2));
64 
65             DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5);
66 
67         }
68 
69         print_spinner();
70         {
71             matrix<double,2,1> ch;
72 
73             ch = chebyquad_start(2);
74 
75             solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
76                                 chebyquad_residual,
77                                 derivative(chebyquad_residual),
78                                 range(0,ch.size()-1),
79                                 ch);
80 
81             dlog << LINFO << "chebyquad 2 obj: " << chebyquad(ch);
82             dlog << LINFO << "chebyquad 2 der: " << length(chebyquad_derivative(ch));
83             dlog << LINFO << "chebyquad 2 error: " << length(ch - chebyquad_solution(2));
84 
85             DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5);
86 
87         }
88         print_spinner();
89         {
90             matrix<double,2,1> ch;
91 
92             ch = chebyquad_start(2);
93 
94             solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
95                                 chebyquad_residual,
96                                 derivative(chebyquad_residual),
97                                 range(0,ch.size()-1),
98                                 ch);
99 
100             dlog << LINFO << "LM chebyquad 2 obj: " << chebyquad(ch);
101             dlog << LINFO << "LM chebyquad 2 der: " << length(chebyquad_derivative(ch));
102             dlog << LINFO << "LM chebyquad 2 error: " << length(ch - chebyquad_solution(2));
103 
104             DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5);
105 
106         }
107 
108         print_spinner();
109         {
110             matrix<double,0,1> ch;
111 
112             ch = chebyquad_start(4);
113 
114             solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
115                                 chebyquad_residual,
116                                 derivative(chebyquad_residual),
117                                 range(0,ch.size()-1),
118                                 ch);
119 
120             dlog << LINFO << "chebyquad 4 obj: " << chebyquad(ch);
121             dlog << LINFO << "chebyquad 4 der: " << length(chebyquad_derivative(ch));
122             dlog << LINFO << "chebyquad 4 error: " << length(ch - chebyquad_solution(4));
123 
124             DLIB_TEST(length(ch - chebyquad_solution(4)) < 1e-5);
125 
126         }
127         print_spinner();
128         {
129             matrix<double,0,1> ch;
130 
131             ch = chebyquad_start(4);
132 
133             solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
134                                 chebyquad_residual,
135                                 derivative(chebyquad_residual),
136                                 range(0,ch.size()-1),
137                                 ch);
138 
139             dlog << LINFO << "LM chebyquad 4 obj: " << chebyquad(ch);
140             dlog << LINFO << "LM chebyquad 4 der: " << length(chebyquad_derivative(ch));
141             dlog << LINFO << "LM chebyquad 4 error: " << length(ch - chebyquad_solution(4));
142 
143             DLIB_TEST(length(ch - chebyquad_solution(4)) < 1e-5);
144 
145         }
146 
147 
148         print_spinner();
149         {
150             matrix<double,0,1> ch;
151 
152             ch = chebyquad_start(6);
153 
154             solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
155                                 chebyquad_residual,
156                                 derivative(chebyquad_residual),
157                                 range(0,ch.size()-1),
158                                 ch);
159 
160             dlog << LINFO << "chebyquad 6 obj: " << chebyquad(ch);
161             dlog << LINFO << "chebyquad 6 der: " << length(chebyquad_derivative(ch));
162             dlog << LINFO << "chebyquad 6 error: " << length(ch - chebyquad_solution(6));
163 
164             // the ch variable contains a permutation of what is in chebyquad_solution(6).
165             // Apparently there is more than one minimum?.  Just check that the objective
166             // goes to zero.
167             DLIB_TEST(chebyquad(ch) < 1e-10);
168 
169         }
170         print_spinner();
171         {
172             matrix<double,0,1> ch;
173 
174             ch = chebyquad_start(6);
175 
176             solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
177                                 chebyquad_residual,
178                                 derivative(chebyquad_residual),
179                                 range(0,ch.size()-1),
180                                 ch);
181 
182             dlog << LINFO << "LM chebyquad 6 obj: " << chebyquad(ch);
183             dlog << LINFO << "LM chebyquad 6 der: " << length(chebyquad_derivative(ch));
184             dlog << LINFO << "LM chebyquad 6 error: " << length(ch - chebyquad_solution(6));
185 
186             DLIB_TEST(chebyquad(ch) < 1e-10);
187 
188         }
189 
190 
191         print_spinner();
192         {
193             matrix<double,0,1> ch;
194 
195             ch = chebyquad_start(8);
196 
197             solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
198                                 chebyquad_residual,
199                                 derivative(chebyquad_residual),
200                                 range(0,ch.size()-1),
201                                 ch);
202 
203             dlog << LINFO << "chebyquad 8 obj: " << chebyquad(ch);
204             dlog << LINFO << "chebyquad 8 der: " << length(chebyquad_derivative(ch));
205             dlog << LINFO << "chebyquad 8 error: " << length(ch - chebyquad_solution(8));
206 
207             DLIB_TEST(length(ch - chebyquad_solution(8)) < 1e-5);
208 
209         }
210         print_spinner();
211         {
212             matrix<double,0,1> ch;
213 
214             ch = chebyquad_start(8);
215 
216             solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
217                                 chebyquad_residual,
218                                 derivative(chebyquad_residual),
219                                 range(0,ch.size()-1),
220                                 ch);
221 
222             dlog << LINFO << "LM chebyquad 8 obj: " << chebyquad(ch);
223             dlog << LINFO << "LM chebyquad 8 der: " << length(chebyquad_derivative(ch));
224             dlog << LINFO << "LM chebyquad 8 error: " << length(ch - chebyquad_solution(8));
225 
226             DLIB_TEST(length(ch - chebyquad_solution(8)) < 1e-5);
227 
228         }
229     }
230 
231 // ----------------------------------------------------------------------------------------
232 
test_with_brown()233     void test_with_brown()
234     {
235         print_spinner();
236         {
237             matrix<double,4,1> ch;
238 
239             ch = brown_start();
240 
241             solve_least_squares(objective_delta_stop_strategy(1e-13, 300),
242                                 brown_residual,
243                                 derivative(brown_residual),
244                                 range(1,20),
245                                 ch);
246 
247             dlog << LINFO << "brown obj: " << brown(ch);
248             dlog << LINFO << "brown der: " << length(brown_derivative(ch));
249             dlog << LINFO << "brown error: " << length(ch - brown_solution());
250 
251             DLIB_TEST_MSG(length(ch - brown_solution()) < 1e-5,length(ch - brown_solution()) );
252 
253         }
254         print_spinner();
255         {
256             matrix<double,4,1> ch;
257 
258             ch = brown_start();
259 
260             solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
261                                 brown_residual,
262                                 derivative(brown_residual),
263                                 range(1,20),
264                                 ch);
265 
266             dlog << LINFO << "LM brown obj: " << brown(ch);
267             dlog << LINFO << "LM brown der: " << length(brown_derivative(ch));
268             dlog << LINFO << "LM brown error: " << length(ch - brown_solution());
269 
270             DLIB_TEST(length(ch - brown_solution()) < 1e-5);
271 
272         }
273     }
274 
275 // ----------------------------------------------------------------------------------------
276 
// These wrapper functions are defined here because wrapping the real rosen functions in this
// way avoids triggering a bug in Visual Studio 2005 which prevents this code from compiling.
rosen_residual_double(int i,const matrix<double,2,1> & m)279     double rosen_residual_double (int i, const matrix<double,2,1>& m)
280     { return rosen_residual(i,m); }
rosen_residual_float(int i,const matrix<float,2,1> & m)281     float rosen_residual_float (int i, const matrix<float,2,1>& m)
282     { return rosen_residual(i,m); }
283 
rosen_residual_derivative_double(int i,const matrix<double,2,1> & m)284     matrix<double,2,1> rosen_residual_derivative_double (int i, const matrix<double,2,1>& m)
285     { return rosen_residual_derivative(i,m); }
286     /*
287     matrix<float,2,1> rosen_residual_derivative_float (int i, const matrix<float,2,1>& m)
288     { return rosen_residual_derivative(i,m); }
289     */
290 
rosen_big_residual_double(int i,const matrix<double,2,1> & m)291     double rosen_big_residual_double (int i, const matrix<double,2,1>& m)
292     { return rosen_big_residual(i,m); }
293 
294 // ----------------------------------------------------------------------------------------
295 
test_with_rosen()296     void test_with_rosen()
297     {
298 
299         print_spinner();
300         {
301             matrix<double,2,1> ch;
302 
303             ch = rosen_start<double>();
304 
305             solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
306                                 rosen_residual_double,
307                                 rosen_residual_derivative_double,
308                                 range(1,20),
309                                 ch);
310 
311             dlog << LINFO << "rosen obj: " << rosen(ch);
312             dlog << LINFO << "rosen error: " << length(ch - rosen_solution<double>());
313 
314             DLIB_TEST(length(ch - rosen_solution<double>()) < 1e-5);
315 
316         }
317         print_spinner();
318         {
319             matrix<double,2,1> ch;
320 
321             ch = rosen_start<double>();
322 
323             solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
324                                 rosen_residual_double,
325                                 rosen_residual_derivative_double,
326                                 range(1,20),
327                                 ch);
328 
329             dlog << LINFO << "lm rosen obj: " << rosen(ch);
330             dlog << LINFO << "lm rosen error: " << length(ch - rosen_solution<double>());
331 
332             DLIB_TEST(length(ch - rosen_solution<double>()) < 1e-5);
333 
334         }
335 
336 
337 
338         print_spinner();
339         {
340             matrix<double,2,1> ch;
341 
342             ch = rosen_start<double>();
343 
344             solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
345                                 rosen_residual_double,
346                                 derivative(rosen_residual_double),
347                                 range(1,20),
348                                 ch);
349 
350             dlog << LINFO << "rosen obj: " << rosen(ch);
351             dlog << LINFO << "rosen error: " << length(ch - rosen_solution<double>());
352 
353             DLIB_TEST(length(ch - rosen_solution<double>()) < 1e-5);
354 
355         }
356         print_spinner();
357         {
358             matrix<float,2,1> ch;
359 
360             ch = rosen_start<float>();
361 
362             solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
363                                 rosen_residual_float,
364                                 derivative(rosen_residual_float),
365                                 range(1,20),
366                                 ch);
367 
368             dlog << LINFO << "float rosen obj: " << rosen(ch);
369             dlog << LINFO << "float rosen error: " << length(ch - rosen_solution<float>());
370 
371             DLIB_TEST(length(ch - rosen_solution<float>()) < 1e-5);
372 
373         }
374         print_spinner();
375         {
376             matrix<float,2,1> ch;
377 
378             ch = rosen_start<float>();
379 
380             solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
381                                 rosen_residual_float,
382                                 derivative(rosen_residual_float),
383                                 range(1,20),
384                                 ch);
385 
386             dlog << LINFO << "LM float rosen obj: " << rosen(ch);
387             dlog << LINFO << "LM float rosen error: " << length(ch - rosen_solution<float>());
388 
389             DLIB_TEST(length(ch - rosen_solution<float>()) < 1e-5);
390 
391         }
392         print_spinner();
393         {
394             matrix<double,2,1> ch;
395 
396             ch = rosen_start<double>();
397 
398             solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
399                                 rosen_residual_double,
400                                 derivative(rosen_residual_double),
401                                 range(1,20),
402                                 ch);
403 
404             dlog << LINFO << "LM rosen obj: " << rosen(ch);
405             dlog << LINFO << "LM rosen error: " << length(ch - rosen_solution<double>());
406 
407             DLIB_TEST(length(ch - rosen_solution<double>()) < 1e-5);
408 
409         }
410         print_spinner();
411         {
412             matrix<double,2,1> ch;
413 
414             ch = rosen_big_start<double>();
415 
416             solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
417                                 rosen_big_residual_double,
418                                 derivative(rosen_big_residual_double),
419                                 range(1,2),
420                                 ch);
421 
422             dlog << LINFO << "rosen big obj: " << rosen_big(ch);
423             dlog << LINFO << "rosen big error: " << length(ch - rosen_big_solution<double>());
424 
425             DLIB_TEST(length(ch - rosen_big_solution<double>()) < 1e-5);
426 
427         }
428     }
429 
430 // ----------------------------------------------------------------------------------------
431 
432     class optimization_tester : public tester
433     {
434     public:
optimization_tester()435         optimization_tester (
436         ) :
437             tester ("test_least_squares",
438                     "Runs tests on the least squares optimization component.")
439         {}
440 
perform_test()441         void perform_test (
442         )
443         {
444             test_with_chebyquad();
445             test_with_brown();
446             test_with_rosen();
447         }
448     } a;
449 
450 }
451 
452 
453