1 // Copyright (C) 2010 Davis E. King (davis@dlib.net) 2 // License: Boost Software License See LICENSE.txt for the full license. 3 #ifndef DLIB_SVm_ONE_CLASS_TRAINER_Hh_ 4 #define DLIB_SVm_ONE_CLASS_TRAINER_Hh_ 5 6 #include "svm_one_class_trainer_abstract.h" 7 #include <cmath> 8 #include <limits> 9 #include <sstream> 10 #include "../matrix.h" 11 #include "../algs.h" 12 13 #include "function.h" 14 #include "kernel.h" 15 #include "../optimization/optimization_solve_qp3_using_smo.h" 16 17 namespace dlib 18 { 19 20 // ---------------------------------------------------------------------------------------- 21 22 template < 23 typename K 24 > 25 class svm_one_class_trainer 26 { 27 public: 28 typedef K kernel_type; 29 typedef typename kernel_type::scalar_type scalar_type; 30 typedef typename kernel_type::sample_type sample_type; 31 typedef typename kernel_type::mem_manager_type mem_manager_type; 32 typedef decision_function<kernel_type> trained_function_type; 33 svm_one_class_trainer()34 svm_one_class_trainer ( 35 ) : 36 nu(0.1), 37 cache_size(200), 38 eps(0.001) 39 { 40 } 41 svm_one_class_trainer(const kernel_type & kernel_,const scalar_type & nu_)42 svm_one_class_trainer ( 43 const kernel_type& kernel_, 44 const scalar_type& nu_ 45 ) : 46 kernel_function(kernel_), 47 nu(nu_), 48 cache_size(200), 49 eps(0.001) 50 { 51 // make sure requires clause is not broken 52 DLIB_ASSERT(0 < nu && nu <= 1, 53 "\tsvm_one_class_trainer::svm_one_class_trainer(kernel,nu)" 54 << "\n\t invalid inputs were given to this function" 55 << "\n\t nu: " << nu 56 ); 57 } 58 set_cache_size(long cache_size_)59 void set_cache_size ( 60 long cache_size_ 61 ) 62 { 63 // make sure requires clause is not broken 64 DLIB_ASSERT(cache_size_ > 0, 65 "\tvoid svm_one_class_trainer::set_cache_size(cache_size_)" 66 << "\n\t invalid inputs were given to this function" 67 << "\n\t cache_size: " << cache_size_ 68 ); 69 cache_size = cache_size_; 70 } 71 get_cache_size()72 long get_cache_size ( 73 ) const 74 { 75 return cache_size; 76 } 77 set_epsilon(scalar_type eps_)78 void set_epsilon ( 79 scalar_type eps_ 80 ) 81 { 82 // make sure requires clause is not broken 83 DLIB_ASSERT(eps_ > 0, 84 "\tvoid svm_one_class_trainer::set_epsilon(eps_)" 85 << "\n\t invalid inputs were given to this function" 86 << "\n\t eps: " << eps_ 87 ); 88 eps = eps_; 89 } 90 get_epsilon()91 const scalar_type get_epsilon ( 92 ) const 93 { 94 return eps; 95 } 96 set_kernel(const kernel_type & k)97 void set_kernel ( 98 const kernel_type& k 99 ) 100 { 101 kernel_function = k; 102 } 103 get_kernel()104 const kernel_type& get_kernel ( 105 ) const 106 { 107 return kernel_function; 108 } 109 set_nu(scalar_type nu_)110 void set_nu ( 111 scalar_type nu_ 112 ) 113 { 114 // make sure requires clause is not broken 115 DLIB_ASSERT(0 < nu_ && nu_ <= 1, 116 "\tvoid svm_one_class_trainer::set_nu(nu_)" 117 << "\n\t invalid inputs were given to this function" 118 << "\n\t nu: " << nu_ 119 ); 120 nu = nu_; 121 } 122 get_nu()123 const scalar_type get_nu ( 124 ) const 125 { 126 return nu; 127 } 128 129 template < 130 typename in_sample_vector_type 131 > train(const in_sample_vector_type & x)132 const decision_function<kernel_type> train ( 133 const in_sample_vector_type& x 134 ) const 135 { 136 return do_train(mat(x)); 137 } 138 swap(svm_one_class_trainer & item)139 void swap ( 140 svm_one_class_trainer& item 141 ) 142 { 143 exchange(kernel_function, item.kernel_function); 144 exchange(nu, item.nu); 145 exchange(cache_size, item.cache_size); 146 exchange(eps, item.eps); 147 } 148 149 private: 150 151 // ------------------------------------------------------------------------------------ 152 153 template < 154 typename in_sample_vector_type 155 > do_train(const in_sample_vector_type & x)156 const decision_function<kernel_type> do_train ( 157 const in_sample_vector_type& x 158 ) const 159 { 160 typedef typename K::scalar_type scalar_type; 161 typedef typename decision_function<K>::sample_vector_type sample_vector_type; 162 typedef typename decision_function<K>::scalar_vector_type scalar_vector_type; 163 164 // make sure requires clause is not broken 165 DLIB_ASSERT(is_col_vector(x) && x.size() > 0, 166 "\tdecision_function svm_one_class_trainer::train(x)" 167 << "\n\t invalid inputs were given to this function" 168 << "\n\t x.nr(): " << x.nr() 169 << "\n\t x.nc(): " << x.nc() 170 ); 171 172 173 scalar_vector_type alpha; 174 175 solve_qp3_using_smo<scalar_vector_type> solver; 176 177 solver(symmetric_matrix_cache<float>(kernel_matrix(kernel_function,x), cache_size), 178 zeros_matrix<scalar_type>(x.size(),1), 179 ones_matrix<scalar_type>(x.size(),1), 180 nu*x.size(), 181 1, 182 1, 183 alpha, 184 eps); 185 186 scalar_type rho; 187 calculate_rho(alpha,solver.get_gradient(),rho); 188 189 190 // count the number of support vectors 191 const long sv_count = (long)sum(alpha != 0); 192 193 scalar_vector_type sv_alpha; 194 sample_vector_type support_vectors; 195 196 // size these column vectors so that they have an entry for each support vector 197 sv_alpha.set_size(sv_count); 198 support_vectors.set_size(sv_count); 199 200 // load the support vectors and their alpha values into these new column matrices 201 long idx = 0; 202 for (long i = 0; i < alpha.nr(); ++i) 203 { 204 if (alpha(i) != 0) 205 { 206 sv_alpha(idx) = alpha(i); 207 support_vectors(idx) = x(i); 208 ++idx; 209 } 210 } 211 212 // now return the decision function 213 return decision_function<K> (sv_alpha, rho, kernel_function, support_vectors); 214 } 215 216 // ------------------------------------------------------------------------------------ 217 218 template < 219 typename scalar_vector_type 220 > calculate_rho(const scalar_vector_type & alpha,const scalar_vector_type & df,scalar_type & rho)221 void calculate_rho( 222 const scalar_vector_type& alpha, 223 const scalar_vector_type& df, 224 scalar_type& rho 225 ) const 226 { 227 using namespace std; 228 long num_p_free = 0; 229 scalar_type sum_p_free = 0; 230 231 232 scalar_type upper_bound_p; 233 scalar_type lower_bound_p; 234 235 find_min_and_max(df, upper_bound_p, lower_bound_p); 236 237 for(long i = 0; i < alpha.nr(); ++i) 238 { 239 if(alpha(i) == 1) 240 { 241 if (df(i) > upper_bound_p) 242 upper_bound_p = df(i); 243 } 244 else if(alpha(i) == 0) 245 { 246 if (df(i) < lower_bound_p) 247 lower_bound_p = df(i); 248 } 249 else 250 { 251 ++num_p_free; 252 sum_p_free += df(i); 253 } 254 } 255 256 scalar_type r1; 257 if(num_p_free > 0) 258 r1 = sum_p_free/num_p_free; 259 else 260 r1 = (upper_bound_p+lower_bound_p)/2; 261 262 rho = r1; 263 } 264 265 kernel_type kernel_function; 266 scalar_type nu; 267 long cache_size; 268 scalar_type eps; 269 }; // end of class svm_one_class_trainer 270 271 // ---------------------------------------------------------------------------------------- 272 273 template <typename K> swap(svm_one_class_trainer<K> & a,svm_one_class_trainer<K> & b)274 void swap ( 275 svm_one_class_trainer<K>& a, 276 svm_one_class_trainer<K>& b 277 ) { a.swap(b); } 278 279 // ---------------------------------------------------------------------------------------- 280 281 } 282 283 #endif // DLIB_SVm_ONE_CLASS_TRAINER_Hh_ 284 285