1 import libsvm.*;
2 import java.io.*;
3 import java.util.*;
4 
5 class svm_predict {
6 	private static svm_print_interface svm_print_null = new svm_print_interface()
7 	{
8 		public void print(String s) {}
9 	};
10 
11 	private static svm_print_interface svm_print_stdout = new svm_print_interface()
12 	{
13 		public void print(String s)
14 		{
15 			System.out.print(s);
16 		}
17 	};
18 
19 	private static svm_print_interface svm_print_string = svm_print_stdout;
20 
info(String s)21 	static void info(String s)
22 	{
23 		svm_print_string.print(s);
24 	}
25 
atof(String s)26 	private static double atof(String s)
27 	{
28 		return Double.valueOf(s).doubleValue();
29 	}
30 
atoi(String s)31 	private static int atoi(String s)
32 	{
33 		return Integer.parseInt(s);
34 	}
35 
predict(BufferedReader input, DataOutputStream output, svm_model model, int predict_probability)36 	private static void predict(BufferedReader input, DataOutputStream output, svm_model model, int predict_probability) throws IOException
37 	{
38 		int correct = 0;
39 		int total = 0;
40 		double error = 0;
41 		double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
42 
43 		int svm_type=svm.svm_get_svm_type(model);
44 		int nr_class=svm.svm_get_nr_class(model);
45 		double[] prob_estimates=null;
46 
47 		if(predict_probability == 1)
48 		{
49 			if(svm_type == svm_parameter.EPSILON_SVR ||
50 			   svm_type == svm_parameter.NU_SVR)
51 			{
52 				svm_predict.info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma="+svm.svm_get_svr_probability(model)+"\n");
53 			}
54 			else
55 			{
56 				int[] labels=new int[nr_class];
57 				svm.svm_get_labels(model,labels);
58 				prob_estimates = new double[nr_class];
59 				output.writeBytes("labels");
60 				for(int j=0;j<nr_class;j++)
61 					output.writeBytes(" "+labels[j]);
62 				output.writeBytes("\n");
63 			}
64 		}
65 		while(true)
66 		{
67 			String line = input.readLine();
68 			if(line == null) break;
69 
70 			StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
71 
72 			double target_label = atof(st.nextToken());
73 			int m = st.countTokens()/2;
74 			svm_node[] x = new svm_node[m];
75 			for(int j=0;j<m;j++)
76 			{
77 				x[j] = new svm_node();
78 				x[j].index = atoi(st.nextToken());
79 				x[j].value = atof(st.nextToken());
80 			}
81 
82 			double predict_label;
83 			if (predict_probability==1 && (svm_type==svm_parameter.C_SVC || svm_type==svm_parameter.NU_SVC))
84 			{
85 				predict_label = svm.svm_predict_probability(model,x,prob_estimates);
86 				output.writeBytes(predict_label+" ");
87 				for(int j=0;j<nr_class;j++)
88 					output.writeBytes(prob_estimates[j]+" ");
89 				output.writeBytes("\n");
90 			}
91 			else
92 			{
93 				predict_label = svm.svm_predict(model,x);
94 				output.writeBytes(predict_label+"\n");
95 			}
96 
97 			if(predict_label == target_label)
98 				++correct;
99 			error += (predict_label-target_label)*(predict_label-target_label);
100 			sump += predict_label;
101 			sumt += target_label;
102 			sumpp += predict_label*predict_label;
103 			sumtt += target_label*target_label;
104 			sumpt += predict_label*target_label;
105 			++total;
106 		}
107 		if(svm_type == svm_parameter.EPSILON_SVR ||
108 		   svm_type == svm_parameter.NU_SVR)
109 		{
110 			svm_predict.info("Mean squared error = "+error/total+" (regression)\n");
111 			svm_predict.info("Squared correlation coefficient = "+
112 				 ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
113 				 ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))+
114 				 " (regression)\n");
115 		}
116 		else
117 			svm_predict.info("Accuracy = "+(double)correct/total*100+
118 				 "% ("+correct+"/"+total+") (classification)\n");
119 	}
120 
exit_with_help()121 	private static void exit_with_help()
122 	{
123 		System.err.print("usage: svm_predict [options] test_file model_file output_file\n"
124 		+"options:\n"
125 		+"-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n"
126 		+"-q : quiet mode (no outputs)\n");
127 		System.exit(1);
128 	}
129 
main(String argv[])130 	public static void main(String argv[]) throws IOException
131 	{
132 		int i, predict_probability=0;
133         	svm_print_string = svm_print_stdout;
134 
135 		// parse options
136 		for(i=0;i<argv.length;i++)
137 		{
138 			if(argv[i].charAt(0) != '-') break;
139 			++i;
140 			switch(argv[i-1].charAt(1))
141 			{
142 				case 'b':
143 					predict_probability = atoi(argv[i]);
144 					break;
145 				case 'q':
146 					svm_print_string = svm_print_null;
147 					i--;
148 					break;
149 				default:
150 					System.err.print("Unknown option: " + argv[i-1] + "\n");
151 					exit_with_help();
152 			}
153 		}
154 		if(i>=argv.length-2)
155 			exit_with_help();
156 		try
157 		{
158 			BufferedReader input = new BufferedReader(new FileReader(argv[i]));
159 			DataOutputStream output = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(argv[i+2])));
160 			svm_model model = svm.svm_load_model(argv[i+1]);
161 			if (model == null)
162 			{
163 				System.err.print("can't open model file "+argv[i+1]+"\n");
164 				System.exit(1);
165 			}
166 			if(predict_probability == 1)
167 			{
168 				if(svm.svm_check_probability_model(model)==0)
169 				{
170 					System.err.print("Model does not support probabiliy estimates\n");
171 					System.exit(1);
172 				}
173 			}
174 			else
175 			{
176 				if(svm.svm_check_probability_model(model)!=0)
177 				{
178 					svm_predict.info("Model supports probability estimates, but disabled in prediction.\n");
179 				}
180 			}
181 			predict(input,output,model,predict_probability);
182 			input.close();
183 			output.close();
184 		}
185 		catch(FileNotFoundException e)
186 		{
187 			exit_with_help();
188 		}
189 		catch(ArrayIndexOutOfBoundsException e)
190 		{
191 			exit_with_help();
192 		}
193 	}
194 }
195