1 #include <stdio.h>
2 #include <ctype.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <errno.h>
6 #include "svm.h"
7 
/* printf-compatible sink that discards everything; installed as info by -q. */
int print_null(const char *s,...)
{
	(void) s;
	return 0;
}
9 
/* Progress/result reporter: printf by default, swapped for print_null by -q. */
static int (*info)(const char *fmt,...) = printf;

/* Node buffer for the instance being predicted and its current capacity;
   allocated in main() and enlarged during prediction when a line has more features. */
struct svm_node *x;
int max_nr_attr = 64;

/* Model loaded from model_file, and the -b flag (non-zero = probability output). */
struct svm_model *model;
int predict_probability = 0;

/* Growable buffer filled by readline() and its current size in bytes. */
static char *line = NULL;
static int max_line_len = 0;
20 
readline(FILE * input)21 static char* readline(FILE *input)
22 {
23 	int len;
24 
25 	if(fgets(line,max_line_len,input) == NULL)
26 		return NULL;
27 
28 	while(strrchr(line,'\n') == NULL)
29 	{
30 		max_line_len *= 2;
31 		line = (char *) realloc(line,max_line_len);
32 		len = (int) strlen(line);
33 		if(fgets(line+len,max_line_len-len,input) == NULL)
34 			break;
35 	}
36 	return line;
37 }
38 
/*
 * Abort on a malformed test-file line: print its number (as passed by the
 * caller) to stderr and terminate the process with status 1.
 */
void exit_input_error(int line_num)
{
	const int status = 1;

	(void) fprintf(stderr, "Wrong input format at line %d\n", line_num);
	exit(status);
}
44 
predict(FILE * input,FILE * output)45 void predict(FILE *input, FILE *output)
46 {
47 	int correct = 0;
48 	int total = 0;
49 	double error = 0;
50 	double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
51 
52 	int svm_type=svm_get_svm_type(model);
53 	int nr_class=svm_get_nr_class(model);
54 	double *prob_estimates=NULL;
55 	int j;
56 
57 	if(predict_probability)
58 	{
59 		if (svm_type==NU_SVR || svm_type==EPSILON_SVR)
60 			info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model));
61 		else
62 		{
63 			int *labels=(int *) malloc(nr_class*sizeof(int));
64 			svm_get_labels(model,labels);
65 			prob_estimates = (double *) malloc(nr_class*sizeof(double));
66 			fprintf(output,"labels");
67 			for(j=0;j<nr_class;j++)
68 				fprintf(output," %d",labels[j]);
69 			fprintf(output,"\n");
70 			free(labels);
71 		}
72 	}
73 
74 	max_line_len = 1024;
75 	line = (char *)malloc(max_line_len*sizeof(char));
76 	while(readline(input) != NULL)
77 	{
78 		int i = 0;
79 		double target_label, predict_label;
80 		char *idx, *val, *label, *endptr;
81 		int inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
82 
83 		label = strtok(line," \t\n");
84 		if(label == NULL) // empty line
85 			exit_input_error(total+1);
86 
87 		target_label = strtod(label,&endptr);
88 		if(endptr == label || *endptr != '\0')
89 			exit_input_error(total+1);
90 
91 		while(1)
92 		{
93 			if(i>=max_nr_attr-1)	// need one more for index = -1
94 			{
95 				max_nr_attr *= 2;
96 				x = (struct svm_node *) realloc(x,max_nr_attr*sizeof(struct svm_node));
97 			}
98 
99 			idx = strtok(NULL,":");
100 			val = strtok(NULL," \t");
101 
102 			if(val == NULL)
103 				break;
104 			errno = 0;
105 			x[i].index = (int) strtol(idx,&endptr,10);
106 			if(endptr == idx || errno != 0 || *endptr != '\0' || x[i].index <= inst_max_index)
107 				exit_input_error(total+1);
108 			else
109 				inst_max_index = x[i].index;
110 
111 			errno = 0;
112 			x[i].value = strtod(val,&endptr);
113 			if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
114 				exit_input_error(total+1);
115 
116 			++i;
117 		}
118 		x[i].index = -1;
119 
120 		if (predict_probability && (svm_type==C_SVC || svm_type==NU_SVC))
121 		{
122 			predict_label = svm_predict_probability(model,x,prob_estimates);
123 			fprintf(output,"%g",predict_label);
124 			for(j=0;j<nr_class;j++)
125 				fprintf(output," %g",prob_estimates[j]);
126 			fprintf(output,"\n");
127 		}
128 		else
129 		{
130 			predict_label = svm_predict(model,x);
131 			fprintf(output,"%.17g\n",predict_label);
132 		}
133 
134 		if(predict_label == target_label)
135 			++correct;
136 		error += (predict_label-target_label)*(predict_label-target_label);
137 		sump += predict_label;
138 		sumt += target_label;
139 		sumpp += predict_label*predict_label;
140 		sumtt += target_label*target_label;
141 		sumpt += predict_label*target_label;
142 		++total;
143 	}
144 	if (svm_type==NU_SVR || svm_type==EPSILON_SVR)
145 	{
146 		info("Mean squared error = %g (regression)\n",error/total);
147 		info("Squared correlation coefficient = %g (regression)\n",
148 			((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
149 			((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))
150 			);
151 	}
152 	else
153 		info("Accuracy = %g%% (%d/%d) (classification)\n",
154 			(double)correct/total*100,correct,total);
155 	if(predict_probability)
156 		free(prob_estimates);
157 }
158 
/* Print the command-line usage summary to stdout and terminate with status 1. */
void exit_with_help()
{
	static const char usage[] =
		"Usage: svm-predict [options] test_file model_file output_file\n"
		"options:\n"
		"-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); for one-class SVM only 0 is supported\n"
		"-q : quiet mode (no outputs)\n";

	fputs(usage, stdout);
	exit(1);
}
169 
main(int argc,char ** argv)170 int main(int argc, char **argv)
171 {
172 	FILE *input, *output;
173 	int i;
174 	// parse options
175 	for(i=1;i<argc;i++)
176 	{
177 		if(argv[i][0] != '-') break;
178 		++i;
179 		switch(argv[i-1][1])
180 		{
181 			case 'b':
182 				predict_probability = atoi(argv[i]);
183 				break;
184 			case 'q':
185 				info = &print_null;
186 				i--;
187 				break;
188 			default:
189 				fprintf(stderr,"Unknown option: -%c\n", argv[i-1][1]);
190 				exit_with_help();
191 		}
192 	}
193 
194 	if(i>=argc-2)
195 		exit_with_help();
196 
197 	input = fopen(argv[i],"r");
198 	if(input == NULL)
199 	{
200 		fprintf(stderr,"can't open input file %s\n",argv[i]);
201 		exit(1);
202 	}
203 
204 	output = fopen(argv[i+2],"w");
205 	if(output == NULL)
206 	{
207 		fprintf(stderr,"can't open output file %s\n",argv[i+2]);
208 		exit(1);
209 	}
210 
211 	if((model=svm_load_model(argv[i+1]))==0)
212 	{
213 		fprintf(stderr,"can't open model file %s\n",argv[i+1]);
214 		exit(1);
215 	}
216 
217 	x = (struct svm_node *) malloc(max_nr_attr*sizeof(struct svm_node));
218 	if(predict_probability)
219 	{
220 		if(svm_check_probability_model(model)==0)
221 		{
222 			fprintf(stderr,"Model does not support probabiliy estimates\n");
223 			exit(1);
224 		}
225 	}
226 	else
227 	{
228 		if(svm_check_probability_model(model)!=0)
229 			info("Model supports probability estimates, but disabled in prediction.\n");
230 	}
231 
232 	predict(input,output);
233 	svm_free_and_destroy_model(&model);
234 	free(x);
235 	free(line);
236 	fclose(input);
237 	fclose(output);
238 	return 0;
239 }
240