// Copyright (C) 2010 Davis E. King (davis@dlib.net)
// License: Boost Software License. See LICENSE.txt for the full license.

// Tests for the least-squares solvers solve_least_squares() and
// solve_least_squares_lm(). Each case starts from a problem's *_start() point
// (chebyquad, brown, rosen, rosen_big), runs a solver, logs the final objective
// (and for some problems the derivative norm) plus the distance from the
// problem's *_solution(), and then asserts that the result is close enough.

// NOTE(review): in this copy several #include directives have no header name.
// They appear to have been stripped, most likely angle-bracket includes. Also,
// `matrix` is used without template arguments throughout. In dlib `matrix` is
// presumably a class template, so those arguments were likely stripped too.
// Restore both from upstream before building.
#include 
#include "optimization_test_functions.h"
#include 
#include 
#include 
#include 
#include 
#include "../rand.h"
#include "tester.h"

namespace
{
    using namespace test;
    using namespace dlib;
    using namespace std;
    using namespace dlib::test_functions;

    // Logger used for all diagnostic output in this file.
    logger dlog("test.least_squares");

// ----------------------------------------------------------------------------------------

    // Runs both solvers on the chebyquad problem for sizes 2, 4, 6 and 8.
    // Every case uses objective_delta_stop_strategy(1e-13, 80): presumably a
    // minimum objective change of 1e-13 and a cap of 80 iterations.
    // derivative(chebyquad_residual) supplies the residual derivative, presumably
    // by numerical approximation. The residual index range is
    // range(0,ch.size()-1).
    void test_with_chebyquad()
    {
        print_spinner();
        {
            matrix ch;

            ch = chebyquad_start(2);

            solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
                                chebyquad_residual,
                                derivative(chebyquad_residual),
                                range(0,ch.size()-1),
                                ch);

            dlog << LINFO << "chebyquad 2 obj: " << chebyquad(ch);
            dlog << LINFO << "chebyquad 2 der: " << length(chebyquad_derivative(ch));
            dlog << LINFO << "chebyquad 2 error: " << length(ch - chebyquad_solution(2));

            DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5);
        }
        {
            // Same problem as above, solved with the Levenberg-Marquardt variant.
            matrix ch;

            ch = chebyquad_start(2);

            solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
                                   chebyquad_residual,
                                   derivative(chebyquad_residual),
                                   range(0,ch.size()-1),
                                   ch);

            dlog << LINFO << "LM chebyquad 2 obj: " << chebyquad(ch);
            dlog << LINFO << "LM chebyquad 2 der: " << length(chebyquad_derivative(ch));
            dlog << LINFO << "LM chebyquad 2 error: " << length(ch - chebyquad_solution(2));

            DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5);
        }

        // NOTE(review): as shown here, this block and the next one are textually
        // identical to the two chebyquad(2) blocks above. Upstream they presumably
        // differed only in the (stripped) matrix template arguments, e.g. a
        // fixed-size vs. a dynamically sized vector. Confirm before deleting them
        // as duplicates.
        print_spinner();
        {
            matrix ch;

            ch = chebyquad_start(2);

            solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
                                chebyquad_residual,
                                derivative(chebyquad_residual),
                                range(0,ch.size()-1),
                                ch);

            dlog << LINFO << "chebyquad 2 obj: " << chebyquad(ch);
            dlog << LINFO << "chebyquad 2 der: " << length(chebyquad_derivative(ch));
            dlog << LINFO << "chebyquad 2 error: " << length(ch - chebyquad_solution(2));

            DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5);
        }
        print_spinner();
        {
            matrix ch;

            ch = chebyquad_start(2);

            solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
                                   chebyquad_residual,
                                   derivative(chebyquad_residual),
                                   range(0,ch.size()-1),
                                   ch);

            dlog << LINFO << "LM chebyquad 2 obj: " << chebyquad(ch);
            dlog << LINFO << "LM chebyquad 2 der: " << length(chebyquad_derivative(ch));
            dlog << LINFO << "LM chebyquad 2 error: " << length(ch - chebyquad_solution(2));

            DLIB_TEST(length(ch - chebyquad_solution(2)) < 1e-5);
        }

        // Size 4.
        print_spinner();
        {
            matrix ch;

            ch = chebyquad_start(4);

            solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
                                chebyquad_residual,
                                derivative(chebyquad_residual),
                                range(0,ch.size()-1),
                                ch);

            dlog << LINFO << "chebyquad 4 obj: " << chebyquad(ch);
            dlog << LINFO << "chebyquad 4 der: " << length(chebyquad_derivative(ch));
            dlog << LINFO << "chebyquad 4 error: " << length(ch - chebyquad_solution(4));

            DLIB_TEST(length(ch - chebyquad_solution(4)) < 1e-5);
        }
        print_spinner();
        {
            matrix ch;

            ch = chebyquad_start(4);

            solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
                                   chebyquad_residual,
                                   derivative(chebyquad_residual),
                                   range(0,ch.size()-1),
                                   ch);

            dlog << LINFO << "LM chebyquad 4 obj: " << chebyquad(ch);
            dlog << LINFO << "LM chebyquad 4 der: " << length(chebyquad_derivative(ch));
            dlog << LINFO << "LM chebyquad 4 error: " << length(ch - chebyquad_solution(4));

            DLIB_TEST(length(ch - chebyquad_solution(4)) < 1e-5);
        }

        // Size 6.
        print_spinner();
        {
            matrix ch;

            ch = chebyquad_start(6);

            solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
                                chebyquad_residual,
                                derivative(chebyquad_residual),
                                range(0,ch.size()-1),
                                ch);

            dlog << LINFO << "chebyquad 6 obj: " << chebyquad(ch);
            dlog << LINFO << "chebyquad 6 der: " << length(chebyquad_derivative(ch));
            dlog << LINFO << "chebyquad 6 error: " << length(ch - chebyquad_solution(6));

            // The solver ends with ch holding a permutation of chebyquad_solution(6),
            // so the problem apparently has more than one minimizer. Instead of
            // comparing against the stored solution, just check that the objective
            // goes to zero.
            DLIB_TEST(chebyquad(ch) < 1e-10);
        }
        print_spinner();
        {
            matrix ch;

            ch = chebyquad_start(6);

            solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
                                   chebyquad_residual,
                                   derivative(chebyquad_residual),
                                   range(0,ch.size()-1),
                                   ch);

            dlog << LINFO << "LM chebyquad 6 obj: " << chebyquad(ch);
            dlog << LINFO << "LM chebyquad 6 der: " << length(chebyquad_derivative(ch));
            dlog << LINFO << "LM chebyquad 6 error: " << length(ch - chebyquad_solution(6));

            // Same reasoning as the previous block: only require a vanishing objective.
            DLIB_TEST(chebyquad(ch) < 1e-10);
        }

        // Size 8.
        print_spinner();
        {
            matrix ch;

            ch = chebyquad_start(8);

            solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
                                chebyquad_residual,
                                derivative(chebyquad_residual),
                                range(0,ch.size()-1),
                                ch);

            dlog << LINFO << "chebyquad 8 obj: " << chebyquad(ch);
            dlog << LINFO << "chebyquad 8 der: " << length(chebyquad_derivative(ch));
            dlog << LINFO << "chebyquad 8 error: " << length(ch - chebyquad_solution(8));

            DLIB_TEST(length(ch - chebyquad_solution(8)) < 1e-5);
        }
        print_spinner();
        {
            matrix ch;

            ch = chebyquad_start(8);

            solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
                                   chebyquad_residual,
                                   derivative(chebyquad_residual),
                                   range(0,ch.size()-1),
                                   ch);

            dlog << LINFO << "LM chebyquad 8 obj: " << chebyquad(ch);
            dlog << LINFO << "LM chebyquad 8 der: " << length(chebyquad_derivative(ch));
            dlog << LINFO << "LM chebyquad 8 error: " << length(ch - chebyquad_solution(8));

            DLIB_TEST(length(ch - chebyquad_solution(8)) < 1e-5);
        }
    }

// ----------------------------------------------------------------------------------------

    // Runs both solvers on the brown problem over residual indices range(1,20).
    // The non-LM solver gets a stop strategy of (1e-13, 300); the LM solver below
    // gets (1e-13, 80). Presumably the second argument is the iteration cap.
    void test_with_brown()
    {
        print_spinner();
        {
            matrix ch;

            ch = brown_start();

            solve_least_squares(objective_delta_stop_strategy(1e-13, 300),
                                brown_residual,
                                derivative(brown_residual),
                                range(1,20),
                                ch);

            dlog << LINFO << "brown obj: " << brown(ch);
            dlog << LINFO << "brown der: " << length(brown_derivative(ch));
            dlog << LINFO << "brown error: " << length(ch - brown_solution());

            // The second argument is presumably the failure message; it reports
            // the remaining distance to the known solution.
            DLIB_TEST_MSG(length(ch - brown_solution()) < 1e-5,length(ch - brown_solution()) );
        }
        print_spinner();
        {
            matrix ch;

            ch = brown_start();

            solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
                                   brown_residual,
                                   derivative(brown_residual),
                                   range(1,20),
                                   ch);

            dlog << LINFO << "LM brown obj: " << brown(ch);
            dlog << LINFO << "LM brown der: " << length(brown_derivative(ch));
            dlog << LINFO << "LM brown error: " << length(ch - brown_solution());

            DLIB_TEST(length(ch - brown_solution()) < 1e-5);
        }
    }

// ----------------------------------------------------------------------------------------

    // These thin wrappers just forward to the real rosen functions. Per the
    // original author, wrapping them this way avoids a Visual Studio 2005 bug
    // that otherwise prevents this code from compiling.

    // Double-valued residual.
    double rosen_residual_double (int i, const matrix& m)
    { return rosen_residual(i,m); }

    // Same residual returned as float; used to exercise the solvers with
    // float-valued residual functions.
    float rosen_residual_float (int i, const matrix& m)
    { return rosen_residual(i,m); }

    // Explicitly supplied residual derivative, used instead of the derivative()
    // wrapper in the first two rosen tests.
    matrix rosen_residual_derivative_double (int i, const matrix& m)
    { return rosen_residual_derivative(i,m); }

    // Disabled: not referenced by any test in this file.
    /*
    matrix rosen_residual_derivative_float (int i, const matrix& m)
    { return rosen_residual_derivative(i,m); }
    */

    // Residual of the "big" rosen variant, used by the last rosen test.
    double rosen_big_residual_double (int i, const matrix& m)
    { return rosen_big_residual(i,m); }

// ----------------------------------------------------------------------------------------

    // Runs the solvers on the rosen problems with several combinations:
    //   - explicit derivative function vs. derivative() wrapper,
    //   - double vs. float residuals,
    //   - solve_least_squares vs. solve_least_squares_lm.
    // The last case covers rosen_big, using only residual indices range(1,2).
    void test_with_rosen()
    {
        // Explicit derivative, non-LM solver.
        print_spinner();
        {
            matrix ch;

            ch = rosen_start();

            solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
                                rosen_residual_double,
                                rosen_residual_derivative_double,
                                range(1,20),
                                ch);

            dlog << LINFO << "rosen obj: " << rosen(ch);
            dlog << LINFO << "rosen error: " << length(ch - rosen_solution());

            DLIB_TEST(length(ch - rosen_solution()) < 1e-5);
        }

        // Explicit derivative, LM solver.
        print_spinner();
        {
            matrix ch;

            ch = rosen_start();

            solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
                                   rosen_residual_double,
                                   rosen_residual_derivative_double,
                                   range(1,20),
                                   ch);

            dlog << LINFO << "lm rosen obj: " << rosen(ch);
            dlog << LINFO << "lm rosen error: " << length(ch - rosen_solution());

            DLIB_TEST(length(ch - rosen_solution()) < 1e-5);
        }

        // derivative() wrapper, double residuals, non-LM solver.
        print_spinner();
        {
            matrix ch;

            ch = rosen_start();

            solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
                                rosen_residual_double,
                                derivative(rosen_residual_double),
                                range(1,20),
                                ch);

            dlog << LINFO << "rosen obj: " << rosen(ch);
            dlog << LINFO << "rosen error: " << length(ch - rosen_solution());

            DLIB_TEST(length(ch - rosen_solution()) < 1e-5);
        }

        // derivative() wrapper, float residuals, non-LM solver.
        print_spinner();
        {
            matrix ch;

            ch = rosen_start();

            solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
                                rosen_residual_float,
                                derivative(rosen_residual_float),
                                range(1,20),
                                ch);

            dlog << LINFO << "float rosen obj: " << rosen(ch);
            dlog << LINFO << "float rosen error: " << length(ch - rosen_solution());

            DLIB_TEST(length(ch - rosen_solution()) < 1e-5);
        }

        // derivative() wrapper, float residuals, LM solver.
        print_spinner();
        {
            matrix ch;

            ch = rosen_start();

            solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
                                   rosen_residual_float,
                                   derivative(rosen_residual_float),
                                   range(1,20),
                                   ch);

            dlog << LINFO << "LM float rosen obj: " << rosen(ch);
            dlog << LINFO << "LM float rosen error: " << length(ch - rosen_solution());

            DLIB_TEST(length(ch - rosen_solution()) < 1e-5);
        }

        // derivative() wrapper, double residuals, LM solver.
        print_spinner();
        {
            matrix ch;

            ch = rosen_start();

            solve_least_squares_lm(objective_delta_stop_strategy(1e-13, 80),
                                   rosen_residual_double,
                                   derivative(rosen_residual_double),
                                   range(1,20),
                                   ch);

            dlog << LINFO << "LM rosen obj: " << rosen(ch);
            dlog << LINFO << "LM rosen error: " << length(ch - rosen_solution());

            DLIB_TEST(length(ch - rosen_solution()) < 1e-5);
        }

        // rosen_big variant, non-LM solver, residual indices 1..2.
        print_spinner();
        {
            matrix ch;

            ch = rosen_big_start();

            solve_least_squares(objective_delta_stop_strategy(1e-13, 80),
                                rosen_big_residual_double,
                                derivative(rosen_big_residual_double),
                                range(1,2),
                                ch);

            dlog << LINFO << "rosen big obj: " << rosen_big(ch);
            dlog << LINFO << "rosen big error: " << length(ch - rosen_big_solution());

            DLIB_TEST(length(ch - rosen_big_solution()) < 1e-5);
        }
    }

// ----------------------------------------------------------------------------------------

    // Test-harness entry for this file, named "test_least_squares".
    class optimization_tester : public tester
    {
    public:
        optimization_tester (
        ) :
            tester ("test_least_squares",
                    "Runs tests on the least squares optimization component.")
        {}

        // Runs every problem family defined above; presumably invoked by the
        // tester framework.
        void perform_test (
        )
        {
            test_with_chebyquad();
            test_with_brown();
            test_with_rosen();
        }
    } a; // Global instance; presumably its construction registers the test with the harness.

}