// This is mul/clsfy/tests/test_logit_loss_function.cxx
#include <iostream>
#include <string>
#include "testlib/testlib_test.h"
//:
// \file
// \brief Tests the clsfy_logit_loss_function class
// \author TFC

#ifdef _MSC_VER
# include "vcl_msvc_warnings.h"
#endif
#include "vpl/vpl.h" // vpl_unlink()
#include <vnl/vnl_random.h>
#include <vnl/vnl_vector.h>
#include <clsfy/clsfy_logit_loss_function.h>
#include <vpdfl/vpdfl_axis_gaussian.h>
#include <vpdfl/vpdfl_axis_gaussian_sampler.h>
#include <mbl/mbl_data_array_wrapper.h>
19 //: Tests the clsfy_logit_loss_function class
test_logit_loss_function()20 void test_logit_loss_function()
21 {
22 std::cout << "**************************************\n"
23 << " Testing clsfy_logit_loss_function_builder\n"
24 << "**************************************\n";
25
26 // Create data
27 unsigned n_egs=20;
28 unsigned n_dim=10;
29 vnl_random rand1(3857);
30 std::vector<vnl_vector<double> > data(n_egs);
31 vnl_vector<double> class_id(n_egs);
32
33 vnl_vector<double> mean1(n_dim,0.0),mean2(n_dim,1.1),var(n_dim,n_dim);
34
35 vpdfl_axis_gaussian pdf1,pdf2;
36 pdf1.set(mean1,var);
37 pdf2.set(mean2,var);
38 vpdfl_axis_gaussian_sampler sampler1,sampler2;
39 sampler1.set_model(pdf1);
40 sampler2.set_model(pdf2);
41
42 for (unsigned i=0;i<n_egs;++i)
43 {
44 if (i%2==0)
45 {
46 sampler1.sample(data[i]);
47 class_id[i]=1.0;
48 }
49 else
50 {
51 sampler2.sample(data[i]);
52 class_id[i]=-1.0;
53 }
54 }
55
56 double min_p=0.001, alpha=1.0;
57 mbl_data_array_wrapper<vnl_vector<double> > data1(data);
58 clsfy_quad_regulariser quad_reg(alpha);
59 clsfy_logit_loss_function fn(data1,class_id,min_p,&quad_reg);
60
61 vnl_vector<double> w(n_dim+1,0.5);
62 for (unsigned i=0;i<=n_dim;++i) w[i]=0.1*(i+1);
63 std::cout<<"f(w)="<<fn.f(w)<<std::endl;
64
65 // Test the gradient
66 double f0 = fn.f(w);
67 vnl_vector<double> gradient;
68 fn.gradf(w,gradient);
69 double d = 1e-6;
70 for (unsigned i=0;i<=n_dim;++i)
71 {
72 w[i]+=d;
73 double gi = (fn.f(w)-f0)/d;
74 w[i]-=d;
75
76 TEST_NEAR("Gradient",gradient[i],gi,1e-5);
77 }
78
79 double f2;
80 vnl_vector<double> g2;
81 fn.compute(w,&f2,&g2);
82
83 TEST_NEAR("compute f",f0,f2,1e-6);
84 TEST_NEAR("compute g",(gradient-g2).rms(),0,1e-5);
85 }
86
87 TESTMAIN(test_logit_loss_function);
88