1 #ifndef FCL_TEST_LIBSVM_CLASSIFIER_H
2 #define FCL_TEST_LIBSVM_CLASSIFIER_H
3 
4 #include "fcl/learning/classifier.h"
5 #include <libsvm/svm.h>
6 
7 namespace fcl
8 {
9 
10 
11 template<std::size_t N>
12 class LibSVMClassifier : public SVMClassifier<N>
13 {
14 public:
LibSVMClassifier()15   LibSVMClassifier()
16   {
17     param.svm_type = C_SVC;
18     param.kernel_type = RBF;
19     param.degree = 3;
20     param.gamma = 0;	// 1/num_features
21     param.coef0 = 0;
22     param.nu = 0.5;
23     param.cache_size = 100; // can change
24     param.C = 1;
25     param.eps = 1e-3;
26     param.p = 0.1;
27     param.shrinking = 1;    // use shrinking
28     param.probability = 0;
29     param.nr_weight = 0;
30     param.weight_label = NULL;
31     param.weight = NULL;
32 
33     param.nr_weight = 2;
34     param.weight_label = (int *)realloc(param.weight_label, sizeof(int) * param.nr_weight);
35     param.weight = (double *)realloc(param.weight, sizeof(double) * param.nr_weight);
36     param.weight_label[0] = -1;
37     param.weight_label[1] = 1;
38     param.weight[0] = 1;
39     param.weight[1] = 1;
40 
41     model = NULL;
42     x_space = NULL;
43     problem.x = NULL;
44     problem.y = NULL;
45     problem.W = NULL;
46   }
47 
48 
setCSVM()49   void setCSVM() { param.svm_type = C_SVC; }
setNuSVM()50   void setNuSVM() { param.svm_type = NU_SVC; }
setC(FCL_REAL C)51   void setC(FCL_REAL C) { param.C = C; }
setNu(FCL_REAL nu)52   void setNu(FCL_REAL nu) { param.nu = nu; }
setLinearClassifier()53   void setLinearClassifier() { param.kernel_type = LINEAR; }
setNonLinearClassifier()54   void setNonLinearClassifier() { param.kernel_type = RBF; }
setProbability(bool use_probability)55   void setProbability(bool use_probability) { param.probability = use_probability; }
setScaler(const Scaler<N> & scaler_)56   virtual void setScaler(const Scaler<N>& scaler_)
57   {
58     scaler = scaler_;
59   }
60 
setNegativeWeight(FCL_REAL c)61   void setNegativeWeight(FCL_REAL c)
62   {
63     param.weight[0] = c;
64   }
65 
setPositiveWeight(FCL_REAL c)66   void setPositiveWeight(FCL_REAL c)
67   {
68     param.weight[1] = c;
69   }
70 
setEPS(FCL_REAL e)71   void setEPS(FCL_REAL e)
72   {
73     param.eps = e;
74   }
75 
setGamma(FCL_REAL gamma)76   void setGamma(FCL_REAL gamma)
77   {
78     param.gamma = gamma;
79   }
80 
~LibSVMClassifier()81   ~LibSVMClassifier()
82   {
83     svm_destroy_param(&param);
84     svm_free_and_destroy_model(&model);
85     delete [] x_space;
86     delete [] problem.x;
87     delete [] problem.y;
88     delete [] problem.W;
89   }
90 
learn(const std::vector<Item<N>> & data)91   virtual void learn(const std::vector<Item<N> >& data)
92   {
93     if(data.size() == 0) return;
94 
95     if(model) svm_free_and_destroy_model(&model);
96     if(param.gamma == 0) param.gamma = 1.0 / N;
97 
98     problem.l = data.size();
99     if(problem.y) delete [] problem.y;
100     problem.y = new double [problem.l];
101     if(problem.x) delete [] problem.x;
102     problem.x = new svm_node* [problem.l];
103     if(problem.W) delete [] problem.W;
104     problem.W = new double [problem.l];
105     if(x_space) delete [] x_space;
106     x_space = new svm_node [(N + 1) * problem.l];
107 
108     for(std::size_t i = 0; i < data.size(); ++i)
109     {
110       svm_node* cur_x_space = x_space + (N + 1) * i;
111       Vecnf<N> q_scaled = scaler.scale(data[i].q);
112       for(std::size_t j = 0; j < N; ++j)
113       {
114         cur_x_space[j].index = j + 1;
115         cur_x_space[j].value = q_scaled[j];
116       }
117       cur_x_space[N].index = -1;
118 
119       problem.x[i] = cur_x_space;
120       problem.y[i] = (data[i].label ? 1 : -1);
121       problem.W[i] = data[i].w;
122     }
123 
124     model = svm_train(&problem, &param);
125     hyperw_normsqr = svm_hyper_w_normsqr_twoclass(model);
126   }
127 
predict(const std::vector<Vecnf<N>> & qs)128   virtual std::vector<PredictResult> predict(const std::vector<Vecnf<N> >& qs) const
129   {
130     std::vector<PredictResult> predict_results;
131 
132     int nr_class = svm_get_nr_class(model);
133     double* prob_estimates = NULL;
134 
135     svm_node* x = (svm_node*)malloc((N + 1) * sizeof(svm_node));
136     if(param.probability)
137       prob_estimates = (double*)malloc(nr_class * sizeof(double));
138 
139     Vecnf<N> v;
140     for(std::size_t i = 0; i < qs.size(); ++i)
141     {
142       v = scaler.scale(qs[i]);
143       for(std::size_t j = 0; j < N; ++j)
144       {
145         x[j].index = j + 1;
146         x[j].value = v[j];
147       }
148       x[N].index = -1;
149 
150       double predict_label;
151 
152       if(param.probability)
153       {
154         predict_label = svm_predict_probability(model, x, prob_estimates);
155         predict_label = (predict_label > 0) ? 1 : 0;
156         predict_results.push_back(PredictResult(predict_label, *prob_estimates));
157       }
158       else
159       {
160         predict_label = svm_predict(model, x);
161         predict_label = (predict_label > 0) ? 1 : 0;
162         predict_results.push_back(PredictResult(predict_label));
163       }
164     }
165 
166     if(param.probability) free(prob_estimates);
167     free(x);
168 
169     return predict_results;
170   }
171 
predict(const Vecnf<N> & q)172   virtual PredictResult predict(const Vecnf<N>& q) const
173   {
174     return (predict(std::vector<Vecnf<N> >(1, q)))[0];
175   }
176 
save(const std::string & filename)177   void save(const std::string& filename) const
178   {
179     if(model)
180       svm_save_model(filename.c_str(), model);
181   }
182 
getSupportVectors()183   virtual std::vector<Item<N> > getSupportVectors() const
184   {
185     std::vector<Item<N> > results;
186     Item<N> item;
187     for(std::size_t i = 0; i < (std::size_t)model->l; ++i)
188     {
189       for(std::size_t j = 0; j < N; ++j)
190         item.q[j] = model->SV[i][j].value;
191       item.q = scaler.unscale(item.q);
192       int id = model->sv_indices[i] - 1;
193       item.label = (problem.y[id] > 0);
194       results.push_back(item);
195     }
196 
197     return results;
198   }
199 
200   svm_parameter param;
201   svm_problem problem;
202   svm_node* x_space;
203   svm_model* model;
204   double hyperw_normsqr;
205 
206   Scaler<N> scaler;
207 };
208 
209 
210 }
211 
212 #endif
213