1 // This is mul/clsfy/clsfy_logit_loss_function.h 2 #ifndef clsfy_logit_loss_function_h_ 3 #define clsfy_logit_loss_function_h_ 4 //: 5 // \file 6 // \brief Loss function for logit of linear classifier 7 // \author TFC 8 9 #include <vnl/vnl_cost_function.h> 10 #include <mbl/mbl_data_wrapper.h> 11 12 //: Loss function for logit of linear classifier. 13 // For vector v' = (b w') (ie b=y[0], w=(y[1]...y[n])), computes 14 // r(v) - (1/n_eg)sum log[(1-minp)logit(c_i * (b+w.x_i)) + minp] 15 // 16 // This is the sum of log prob of correct classification (+regulariser) 17 // which should be minimised to train the classifier. 18 // 19 // Note: Regularisor only important to deal with case where perfect 20 // classification possible, where scaling v would always increase f(v). 21 // Plausible choice of regularisor is clsfy_quad_regulariser (below) 22 class clsfy_logit_loss_function : public vnl_cost_function 23 { 24 private: 25 mbl_data_wrapper<vnl_vector<double> >& x_; 26 27 //: c[i] = -1 or +1, indicating class of x[i] 28 const vnl_vector<double> & c_; 29 30 //: Min probability (avoids log(zero)) 31 double min_p_; 32 33 //: Optional regularising function 34 vnl_cost_function *reg_fn_; 35 public: 36 clsfy_logit_loss_function(mbl_data_wrapper<vnl_vector<double> >& x, 37 const vnl_vector<double> & c, 38 double min_p, vnl_cost_function* reg_fn=nullptr); 39 40 //: The main function: Compute f(v) 41 double f(vnl_vector<double> const& v) override; 42 43 //: Calculate the gradient of f at parameter vector v. 44 void gradf(vnl_vector<double> const& v, 45 vnl_vector<double>& gradient) override; 46 47 //: Compute f(v) and its gradient (if non-zero pointers supplied) 48 void compute(vnl_vector<double> const& v, 49 double *f, vnl_vector<double>* gradient) override; 50 51 }; 52 53 //: Simple quadratic term used to regularise functions 54 // For vector v' = (b w') (ie b=y[0], w=(y[1]...y[n])), computes 55 // f(v) = alpha*|w|^2 (ie ignores first element, which is bias of linear classifier) 56 class clsfy_quad_regulariser : public vnl_cost_function 57 { 58 private: 59 //: Scaling factor 60 double alpha_; 61 public: 62 clsfy_quad_regulariser(double alpha=1e-6); 63 64 //: The main function: Compute f(v) 65 double f(vnl_vector<double> const& v) override; 66 67 //: Calculate the gradient of f at parameter vector v. 68 void gradf(vnl_vector<double> const& v, 69 vnl_vector<double>& gradient) override; 70 }; 71 72 73 #endif // clsfy_logit_loss_function_h_ 74