1 #include <stdio.h>
2 #include <ctype.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <errno.h>
6 #include "svm.h"
7
/*
 * Silent replacement for printf: accepts a printf-style argument list,
 * produces no output and always reports 0.  main() points `info` at this
 * function when the -q option is given.
 */
int print_null(const char *s,...)
{
	(void) s;	/* the format string and its arguments are deliberately ignored */
	return 0;
}
9
/* Routine used for informational messages; printf unless main() swaps in print_null for -q. */
static int (*info)(const char *fmt,...) = &printf;

/* Feature-node buffer for the instance being parsed; allocated in main() and grown in predict(). */
struct svm_node *x;
/* Number of svm_node slots currently allocated for x; doubled by predict() when a line needs more. */
int max_nr_attr = 64;

/* Model returned by svm_load_model() in main(). */
struct svm_model* model;
/* Value of the -b option (0 by default); non-zero requests probability output. */
int predict_probability=0;

/* Line buffer shared by readline() and predict(); allocated in predict(), freed in main(). */
static char *line = NULL;
/* Current capacity of `line` in bytes; set in predict() and doubled by readline(). */
static int max_line_len;
20
readline(FILE * input)21 static char* readline(FILE *input)
22 {
23 int len;
24
25 if(fgets(line,max_line_len,input) == NULL)
26 return NULL;
27
28 while(strrchr(line,'\n') == NULL)
29 {
30 max_line_len *= 2;
31 line = (char *) realloc(line,max_line_len);
32 len = (int) strlen(line);
33 if(fgets(line+len,max_line_len-len,input) == NULL)
34 break;
35 }
36 return line;
37 }
38
/*
 * Report a malformed input line on stderr and terminate the program.
 *
 * line_num: line number to show in the message (callers in this file pass
 *           the 1-based number of the offending line).
 * Does not return; the process exits with status 1.
 */
void exit_input_error(int line_num)
{
	enum { INPUT_ERROR_STATUS = 1 };

	fprintf(stderr,
	        "Wrong input format at line %d\n",
	        line_num);
	exit(INPUT_ERROR_STATUS);
}
44
predict(FILE * input,FILE * output)45 void predict(FILE *input, FILE *output)
46 {
47 int correct = 0;
48 int total = 0;
49 double error = 0;
50 double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
51
52 int svm_type=svm_get_svm_type(model);
53 int nr_class=svm_get_nr_class(model);
54 double *prob_estimates=NULL;
55 int j;
56
57 if(predict_probability)
58 {
59 if (svm_type==NU_SVR || svm_type==EPSILON_SVR)
60 info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model));
61 else
62 {
63 int *labels=(int *) malloc(nr_class*sizeof(int));
64 svm_get_labels(model,labels);
65 prob_estimates = (double *) malloc(nr_class*sizeof(double));
66 fprintf(output,"labels");
67 for(j=0;j<nr_class;j++)
68 fprintf(output," %d",labels[j]);
69 fprintf(output,"\n");
70 free(labels);
71 }
72 }
73
74 max_line_len = 1024;
75 line = (char *)malloc(max_line_len*sizeof(char));
76 while(readline(input) != NULL)
77 {
78 int i = 0;
79 double target_label, predict_label;
80 char *idx, *val, *label, *endptr;
81 int inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
82
83 label = strtok(line," \t\n");
84 if(label == NULL) // empty line
85 exit_input_error(total+1);
86
87 target_label = strtod(label,&endptr);
88 if(endptr == label || *endptr != '\0')
89 exit_input_error(total+1);
90
91 while(1)
92 {
93 if(i>=max_nr_attr-1) // need one more for index = -1
94 {
95 max_nr_attr *= 2;
96 x = (struct svm_node *) realloc(x,max_nr_attr*sizeof(struct svm_node));
97 }
98
99 idx = strtok(NULL,":");
100 val = strtok(NULL," \t");
101
102 if(val == NULL)
103 break;
104 errno = 0;
105 x[i].index = (int) strtol(idx,&endptr,10);
106 if(endptr == idx || errno != 0 || *endptr != '\0' || x[i].index <= inst_max_index)
107 exit_input_error(total+1);
108 else
109 inst_max_index = x[i].index;
110
111 errno = 0;
112 x[i].value = strtod(val,&endptr);
113 if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
114 exit_input_error(total+1);
115
116 ++i;
117 }
118 x[i].index = -1;
119
120 if (predict_probability && (svm_type==C_SVC || svm_type==NU_SVC))
121 {
122 predict_label = svm_predict_probability(model,x,prob_estimates);
123 fprintf(output,"%g",predict_label);
124 for(j=0;j<nr_class;j++)
125 fprintf(output," %g",prob_estimates[j]);
126 fprintf(output,"\n");
127 }
128 else
129 {
130 predict_label = svm_predict(model,x);
131 fprintf(output,"%.17g\n",predict_label);
132 }
133
134 if(predict_label == target_label)
135 ++correct;
136 error += (predict_label-target_label)*(predict_label-target_label);
137 sump += predict_label;
138 sumt += target_label;
139 sumpp += predict_label*predict_label;
140 sumtt += target_label*target_label;
141 sumpt += predict_label*target_label;
142 ++total;
143 }
144 if (svm_type==NU_SVR || svm_type==EPSILON_SVR)
145 {
146 info("Mean squared error = %g (regression)\n",error/total);
147 info("Squared correlation coefficient = %g (regression)\n",
148 ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
149 ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))
150 );
151 }
152 else
153 info("Accuracy = %g%% (%d/%d) (classification)\n",
154 (double)correct/total*100,correct,total);
155 if(predict_probability)
156 free(prob_estimates);
157 }
158
/*
 * Write the command-line usage text to stdout and terminate the program
 * with status 1.  Does not return.
 */
void exit_with_help()
{
	static const char usage[] =
		"Usage: svm-predict [options] test_file model_file output_file\n"
		"options:\n"
		"-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); for one-class SVM only 0 is supported\n"
		"-q : quiet mode (no outputs)\n";

	fputs(usage, stdout);
	exit(1);
}
169
main(int argc,char ** argv)170 int main(int argc, char **argv)
171 {
172 FILE *input, *output;
173 int i;
174 // parse options
175 for(i=1;i<argc;i++)
176 {
177 if(argv[i][0] != '-') break;
178 ++i;
179 switch(argv[i-1][1])
180 {
181 case 'b':
182 predict_probability = atoi(argv[i]);
183 break;
184 case 'q':
185 info = &print_null;
186 i--;
187 break;
188 default:
189 fprintf(stderr,"Unknown option: -%c\n", argv[i-1][1]);
190 exit_with_help();
191 }
192 }
193
194 if(i>=argc-2)
195 exit_with_help();
196
197 input = fopen(argv[i],"r");
198 if(input == NULL)
199 {
200 fprintf(stderr,"can't open input file %s\n",argv[i]);
201 exit(1);
202 }
203
204 output = fopen(argv[i+2],"w");
205 if(output == NULL)
206 {
207 fprintf(stderr,"can't open output file %s\n",argv[i+2]);
208 exit(1);
209 }
210
211 if((model=svm_load_model(argv[i+1]))==0)
212 {
213 fprintf(stderr,"can't open model file %s\n",argv[i+1]);
214 exit(1);
215 }
216
217 x = (struct svm_node *) malloc(max_nr_attr*sizeof(struct svm_node));
218 if(predict_probability)
219 {
220 if(svm_check_probability_model(model)==0)
221 {
222 fprintf(stderr,"Model does not support probabiliy estimates\n");
223 exit(1);
224 }
225 }
226 else
227 {
228 if(svm_check_probability_model(model)!=0)
229 info("Model supports probability estimates, but disabled in prediction.\n");
230 }
231
232 predict(input,output);
233 svm_free_and_destroy_model(&model);
234 free(x);
235 free(line);
236 fclose(input);
237 fclose(output);
238 return 0;
239 }
240