#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "svm.h"
/*
 * Allocate an uninitialized array of n objects of the given type.
 * Expands to an expression of type `type *` holding malloc's result
 * (NULL when the allocation fails).
 * The whole expansion is parenthesized so that a postfix operator applied to
 * the macro (e.g. Malloc(int,4)[0]) acts on the cast pointer rather than on
 * the raw malloc call.
 */
#define Malloc(type,n) ((type *)malloc((n)*sizeof(type)))
8
/*
 * Message sink that discards its argument.
 * parse_command_line selects it as the print function when -q is given.
 */
void print_null(const char *s)
{
	(void)s;	// deliberately unused
}
10
/*
 * Write the svm-train usage summary to stdout, then terminate the process
 * with exit status 1. This function does not return.
 */
void exit_with_help()
{
	static const char usage_text[] =
		"Usage: svm-train [options] training_set_file [model_file]\n"
		"options:\n"
		"-s svm_type : set type of SVM (default 0)\n"
		" 0 -- C-SVC (multi-class classification)\n"
		" 1 -- nu-SVC (multi-class classification)\n"
		" 2 -- one-class SVM\n"
		" 3 -- epsilon-SVR (regression)\n"
		" 4 -- nu-SVR (regression)\n"
		"-t kernel_type : set type of kernel function (default 2)\n"
		" 0 -- linear: u'*v\n"
		" 1 -- polynomial: (gamma*u'*v + coef0)^degree\n"
		" 2 -- radial basis function: exp(-gamma*|u-v|^2)\n"
		" 3 -- sigmoid: tanh(gamma*u'*v + coef0)\n"
		" 4 -- precomputed kernel (kernel values in training_set_file)\n"
		"-d degree : set degree in kernel function (default 3)\n"
		"-g gamma : set gamma in kernel function (default 1/num_features)\n"
		"-r coef0 : set coef0 in kernel function (default 0)\n"
		"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n"
		"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n"
		"-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n"
		"-m cachesize : set cache memory size in MB (default 100)\n"
		"-e epsilon : set tolerance of termination criterion (default 0.001)\n"
		"-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)\n"
		"-b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n"
		"-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)\n"
		"-v n: n-fold cross validation mode\n"
		"-q : quiet mode (no outputs)\n";

	// The text contains no conversion specifiers, so it is emitted verbatim.
	fputs(usage_text, stdout);
	exit(1);
}
44
/*
 * Report a malformed input line on stderr and terminate the process with
 * exit status 1. line_num is the line number shown in the message.
 * This function does not return.
 */
void exit_input_error(int line_num)
{
	const int failure_status = 1;

	fprintf(stderr, "Wrong input format at line %d\n", line_num);
	exit(failure_status);
}
50
51 void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name);
52 void read_problem(const char *filename);
53 void do_cross_validation();
54
55 struct svm_parameter param; // set by parse_command_line
56 struct svm_problem prob; // set by read_problem
57 struct svm_model *model;
58 struct svm_node *x_space;
59 int cross_validation;
60 int nr_fold;
61
// Growable buffer holding the most recently read input line.
// read_problem allocates it and main frees it.
static char *line = NULL;
// Current capacity of `line`, in bytes.
static int max_line_len;

/*
 * Read the next line of `input` into the shared buffer `line`, growing the
 * buffer as needed so the complete line fits.
 *
 * Preconditions: `line` points to a heap buffer of `max_line_len` bytes.
 *
 * Returns `line` (which may have been moved by realloc) on success, or NULL
 * when nothing more can be read. The trailing '\n' is kept; the last line of
 * a file that lacks a final newline is returned without one.
 *
 * On allocation failure, or if the buffer would exceed INT_MAX bytes, an
 * error is printed to stderr and the process exits with status 1.
 */
static char* readline(FILE *input)
{
	if(fgets(line,max_line_len,input) == NULL)
		return NULL;

	// fgets stops once the buffer is full; keep doubling the buffer and
	// appending until the newline has been read or the input ends.
	while(strrchr(line,'\n') == NULL)
	{
		char *grown;
		int len;

		if(max_line_len > INT_MAX/2)
		{
			fprintf(stderr,"input line is too long\n");
			exit(1);
		}

		// Keep the old pointer until realloc is known to have succeeded.
		grown = (char *) realloc(line,(size_t)max_line_len*2);
		if(grown == NULL)
		{
			fprintf(stderr,"can't allocate memory for input line\n");
			exit(1);
		}
		line = grown;
		max_line_len *= 2;

		len = (int) strlen(line);
		if(fgets(line+len,max_line_len-len,input) == NULL)
			break;
	}
	return line;
}
82
main(int argc,char ** argv)83 int main(int argc, char **argv)
84 {
85 char input_file_name[1024];
86 char model_file_name[1024];
87 const char *error_msg;
88
89 parse_command_line(argc, argv, input_file_name, model_file_name);
90 read_problem(input_file_name);
91 error_msg = svm_check_parameter(&prob,¶m);
92
93 if(error_msg)
94 {
95 fprintf(stderr,"ERROR: %s\n",error_msg);
96 exit(1);
97 }
98
99 if(cross_validation)
100 {
101 do_cross_validation();
102 }
103 else
104 {
105 model = svm_train(&prob,¶m);
106 if(svm_save_model(model_file_name,model))
107 {
108 fprintf(stderr, "can't save model to file %s\n", model_file_name);
109 exit(1);
110 }
111 svm_free_and_destroy_model(&model);
112 }
113 svm_destroy_param(¶m);
114 free(prob.y);
115 free(prob.x);
116 free(x_space);
117 free(line);
118
119 return 0;
120 }
121
do_cross_validation()122 void do_cross_validation()
123 {
124 int i;
125 int total_correct = 0;
126 double total_error = 0;
127 double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
128 double *target = Malloc(double,prob.l);
129
130 svm_cross_validation(&prob,¶m,nr_fold,target);
131 if(param.svm_type == EPSILON_SVR ||
132 param.svm_type == NU_SVR)
133 {
134 for(i=0;i<prob.l;i++)
135 {
136 double y = prob.y[i];
137 double v = target[i];
138 total_error += (v-y)*(v-y);
139 sumv += v;
140 sumy += y;
141 sumvv += v*v;
142 sumyy += y*y;
143 sumvy += v*y;
144 }
145 printf("Cross Validation Mean squared error = %g\n",total_error/prob.l);
146 printf("Cross Validation Squared correlation coefficient = %g\n",
147 ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
148 ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))
149 );
150 }
151 else
152 {
153 for(i=0;i<prob.l;i++)
154 if(target[i] == prob.y[i])
155 ++total_correct;
156 printf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
157 }
158 free(target);
159 }
160
parse_command_line(int argc,char ** argv,char * input_file_name,char * model_file_name)161 void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name)
162 {
163 int i;
164 void (*print_func)(const char*) = NULL; // default printing to stdout
165
166 // default values
167 param.svm_type = C_SVC;
168 param.kernel_type = RBF;
169 param.degree = 3;
170 param.gamma = 0; // 1/num_features
171 param.coef0 = 0;
172 param.nu = 0.5;
173 param.cache_size = 100;
174 param.C = 1;
175 param.eps = 1e-3;
176 param.p = 0.1;
177 param.shrinking = 1;
178 param.probability = 0;
179 param.nr_weight = 0;
180 param.weight_label = NULL;
181 param.weight = NULL;
182 cross_validation = 0;
183
184 // parse options
185 for(i=1;i<argc;i++)
186 {
187 if(argv[i][0] != '-') break;
188 if(++i>=argc)
189 exit_with_help();
190 switch(argv[i-1][1])
191 {
192 case 's':
193 param.svm_type = atoi(argv[i]);
194 break;
195 case 't':
196 param.kernel_type = atoi(argv[i]);
197 break;
198 case 'd':
199 param.degree = atoi(argv[i]);
200 break;
201 case 'g':
202 param.gamma = atof(argv[i]);
203 break;
204 case 'r':
205 param.coef0 = atof(argv[i]);
206 break;
207 case 'n':
208 param.nu = atof(argv[i]);
209 break;
210 case 'm':
211 param.cache_size = atof(argv[i]);
212 break;
213 case 'c':
214 param.C = atof(argv[i]);
215 break;
216 case 'e':
217 param.eps = atof(argv[i]);
218 break;
219 case 'p':
220 param.p = atof(argv[i]);
221 break;
222 case 'h':
223 param.shrinking = atoi(argv[i]);
224 break;
225 case 'b':
226 param.probability = atoi(argv[i]);
227 break;
228 case 'q':
229 print_func = &print_null;
230 i--;
231 break;
232 case 'v':
233 cross_validation = 1;
234 nr_fold = atoi(argv[i]);
235 if(nr_fold < 2)
236 {
237 fprintf(stderr,"n-fold cross validation: n must >= 2\n");
238 exit_with_help();
239 }
240 break;
241 case 'w':
242 ++param.nr_weight;
243 param.weight_label = (int *)realloc(param.weight_label,sizeof(int)*param.nr_weight);
244 param.weight = (double *)realloc(param.weight,sizeof(double)*param.nr_weight);
245 param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
246 param.weight[param.nr_weight-1] = atof(argv[i]);
247 break;
248 default:
249 fprintf(stderr,"Unknown option: -%c\n", argv[i-1][1]);
250 exit_with_help();
251 }
252 }
253
254 svm_set_print_string_function(print_func);
255
256 // determine filenames
257
258 if(i>=argc)
259 exit_with_help();
260
261 strcpy(input_file_name, argv[i]);
262
263 if(i<argc-1)
264 strcpy(model_file_name,argv[i+1]);
265 else
266 {
267 char *p = strrchr(argv[i],'/');
268 if(p==NULL)
269 p = argv[i];
270 else
271 ++p;
272 sprintf(model_file_name,"%s.model",p);
273 }
274 }
275
276 // read in a problem (in svmlight format)
277
read_problem(const char * filename)278 void read_problem(const char *filename)
279 {
280 int elements, max_index, inst_max_index, i, j;
281 FILE *fp = fopen(filename,"r");
282 char *endptr;
283 char *idx, *val, *label;
284
285 if(fp == NULL)
286 {
287 fprintf(stderr,"can't open input file %s\n",filename);
288 exit(1);
289 }
290
291 prob.l = 0;
292 elements = 0;
293
294 max_line_len = 1024;
295 line = Malloc(char,max_line_len);
296 while(readline(fp)!=NULL)
297 {
298 char *p = strtok(line," \t"); // label
299
300 // features
301 while(1)
302 {
303 p = strtok(NULL," \t");
304 if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature
305 break;
306 ++elements;
307 }
308 ++elements;
309 ++prob.l;
310 }
311 rewind(fp);
312
313 prob.y = Malloc(double,prob.l);
314 prob.x = Malloc(struct svm_node *,prob.l);
315 x_space = Malloc(struct svm_node,elements);
316
317 max_index = 0;
318 j=0;
319 for(i=0;i<prob.l;i++)
320 {
321 inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
322 readline(fp);
323 prob.x[i] = &x_space[j];
324 label = strtok(line," \t\n");
325 if(label == NULL) // empty line
326 exit_input_error(i+1);
327
328 prob.y[i] = strtod(label,&endptr);
329 if(endptr == label || *endptr != '\0')
330 exit_input_error(i+1);
331
332 while(1)
333 {
334 idx = strtok(NULL,":");
335 val = strtok(NULL," \t");
336
337 if(val == NULL)
338 break;
339
340 errno = 0;
341 x_space[j].index = (int) strtol(idx,&endptr,10);
342 if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
343 exit_input_error(i+1);
344 else
345 inst_max_index = x_space[j].index;
346
347 errno = 0;
348 x_space[j].value = strtod(val,&endptr);
349 if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
350 exit_input_error(i+1);
351
352 ++j;
353 }
354
355 if(inst_max_index > max_index)
356 max_index = inst_max_index;
357 x_space[j++].index = -1;
358 }
359
360 if(param.gamma == 0 && max_index > 0)
361 param.gamma = 1.0/max_index;
362
363 if(param.kernel_type == PRECOMPUTED)
364 for(i=0;i<prob.l;i++)
365 {
366 if (prob.x[i][0].index != 0)
367 {
368 fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n");
369 exit(1);
370 }
371 if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
372 {
373 fprintf(stderr,"Wrong input format: sample_serial_number out of range\n");
374 exit(1);
375 }
376 }
377
378 fclose(fp);
379 }
380