1 // This is mul/clsfy/clsfy_logit_loss_function.h
2 #ifndef clsfy_logit_loss_function_h_
3 #define clsfy_logit_loss_function_h_
4 //:
5 // \file
6 // \brief Loss function for logit of linear classifier
7 // \author TFC
8 
9 #include <vnl/vnl_cost_function.h>
10 #include <mbl/mbl_data_wrapper.h>
11 
12 //: Loss function for logit of linear classifier.
13 //  For vector v' = (b w') (ie b=y[0], w=(y[1]...y[n])), computes
14 //  r(v) - (1/n_eg)sum log[(1-minp)logit(c_i * (b+w.x_i)) + minp]
15 //
16 // This is the sum of log prob of correct classification (+regulariser)
17 // which should be minimised to train the classifier.
18 //
19 // Note: Regularisor only important to deal with case where perfect
20 // classification possible, where scaling v would always increase f(v).
21 // Plausible choice of regularisor is clsfy_quad_regulariser (below)
22 class clsfy_logit_loss_function : public vnl_cost_function
23 {
24 private:
25   mbl_data_wrapper<vnl_vector<double> >& x_;
26 
27   //: c[i] = -1 or +1, indicating class of x[i]
28   const vnl_vector<double> & c_;
29 
30   //: Min probability (avoids log(zero))
31   double min_p_;
32 
33   //: Optional regularising function
34   vnl_cost_function *reg_fn_;
35 public:
36   clsfy_logit_loss_function(mbl_data_wrapper<vnl_vector<double> >& x,
37                             const vnl_vector<double> & c,
38                             double min_p, vnl_cost_function* reg_fn=nullptr);
39 
40   //:  The main function: Compute f(v)
41   double f(vnl_vector<double> const& v) override;
42 
43   //:  Calculate the gradient of f at parameter vector v.
44   void gradf(vnl_vector<double> const& v,
45                      vnl_vector<double>& gradient) override;
46 
47   //: Compute f(v) and its gradient (if non-zero pointers supplied)
48   void compute(vnl_vector<double> const& v,
49                        double *f, vnl_vector<double>* gradient) override;
50 
51 };
52 
53 //: Simple quadratic term used to regularise functions
54 //  For vector v' = (b w') (ie b=y[0], w=(y[1]...y[n])), computes
55 //  f(v) = alpha*|w|^2   (ie ignores first element, which is bias of linear classifier)
56 class clsfy_quad_regulariser : public vnl_cost_function
57 {
58 private:
59   //: Scaling factor
60   double alpha_;
61 public:
62   clsfy_quad_regulariser(double alpha=1e-6);
63 
64   //:  The main function: Compute f(v)
65   double f(vnl_vector<double> const& v) override;
66 
67   //:  Calculate the gradient of f at parameter vector v.
68   void gradf(vnl_vector<double> const& v,
69                      vnl_vector<double>& gradient) override;
70 };
71 
72 
73 #endif // clsfy_logit_loss_function_h_
74