1 import libsvm.*; 2 import java.io.*; 3 import java.util.*; 4 5 class svm_predict { 6 private static svm_print_interface svm_print_null = new svm_print_interface() 7 { 8 public void print(String s) {} 9 }; 10 11 private static svm_print_interface svm_print_stdout = new svm_print_interface() 12 { 13 public void print(String s) 14 { 15 System.out.print(s); 16 } 17 }; 18 19 private static svm_print_interface svm_print_string = svm_print_stdout; 20 info(String s)21 static void info(String s) 22 { 23 svm_print_string.print(s); 24 } 25 atof(String s)26 private static double atof(String s) 27 { 28 return Double.valueOf(s).doubleValue(); 29 } 30 atoi(String s)31 private static int atoi(String s) 32 { 33 return Integer.parseInt(s); 34 } 35 predict(BufferedReader input, DataOutputStream output, svm_model model, int predict_probability)36 private static void predict(BufferedReader input, DataOutputStream output, svm_model model, int predict_probability) throws IOException 37 { 38 int correct = 0; 39 int total = 0; 40 double error = 0; 41 double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0; 42 43 int svm_type=svm.svm_get_svm_type(model); 44 int nr_class=svm.svm_get_nr_class(model); 45 double[] prob_estimates=null; 46 47 if(predict_probability == 1) 48 { 49 if(svm_type == svm_parameter.EPSILON_SVR || 50 svm_type == svm_parameter.NU_SVR) 51 { 52 svm_predict.info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma="+svm.svm_get_svr_probability(model)+"\n"); 53 } 54 else 55 { 56 int[] labels=new int[nr_class]; 57 svm.svm_get_labels(model,labels); 58 prob_estimates = new double[nr_class]; 59 output.writeBytes("labels"); 60 for(int j=0;j<nr_class;j++) 61 output.writeBytes(" "+labels[j]); 62 output.writeBytes("\n"); 63 } 64 } 65 while(true) 66 { 67 String line = input.readLine(); 68 if(line == null) break; 69 70 StringTokenizer st = new StringTokenizer(line," \t\n\r\f:"); 71 72 double target_label = atof(st.nextToken()); 73 int m = st.countTokens()/2; 74 svm_node[] x = new svm_node[m]; 75 for(int j=0;j<m;j++) 76 { 77 x[j] = new svm_node(); 78 x[j].index = atoi(st.nextToken()); 79 x[j].value = atof(st.nextToken()); 80 } 81 82 double predict_label; 83 if (predict_probability==1 && (svm_type==svm_parameter.C_SVC || svm_type==svm_parameter.NU_SVC)) 84 { 85 predict_label = svm.svm_predict_probability(model,x,prob_estimates); 86 output.writeBytes(predict_label+" "); 87 for(int j=0;j<nr_class;j++) 88 output.writeBytes(prob_estimates[j]+" "); 89 output.writeBytes("\n"); 90 } 91 else 92 { 93 predict_label = svm.svm_predict(model,x); 94 output.writeBytes(predict_label+"\n"); 95 } 96 97 if(predict_label == target_label) 98 ++correct; 99 error += (predict_label-target_label)*(predict_label-target_label); 100 sump += predict_label; 101 sumt += target_label; 102 sumpp += predict_label*predict_label; 103 sumtt += target_label*target_label; 104 sumpt += predict_label*target_label; 105 ++total; 106 } 107 if(svm_type == svm_parameter.EPSILON_SVR || 108 svm_type == svm_parameter.NU_SVR) 109 { 110 svm_predict.info("Mean squared error = "+error/total+" (regression)\n"); 111 svm_predict.info("Squared correlation coefficient = "+ 112 ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/ 113 ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))+ 114 " (regression)\n"); 115 } 116 else 117 svm_predict.info("Accuracy = "+(double)correct/total*100+ 118 "% ("+correct+"/"+total+") (classification)\n"); 119 } 120 exit_with_help()121 private static void exit_with_help() 122 { 123 System.err.print("usage: svm_predict [options] test_file model_file output_file\n" 124 +"options:\n" 125 +"-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n" 126 +"-q : quiet mode (no outputs)\n"); 127 System.exit(1); 128 } 129 main(String argv[])130 public static void main(String argv[]) throws IOException 131 { 132 int i, predict_probability=0; 133 svm_print_string = svm_print_stdout; 134 135 // parse options 136 for(i=0;i<argv.length;i++) 137 { 138 if(argv[i].charAt(0) != '-') break; 139 ++i; 140 switch(argv[i-1].charAt(1)) 141 { 142 case 'b': 143 predict_probability = atoi(argv[i]); 144 break; 145 case 'q': 146 svm_print_string = svm_print_null; 147 i--; 148 break; 149 default: 150 System.err.print("Unknown option: " + argv[i-1] + "\n"); 151 exit_with_help(); 152 } 153 } 154 if(i>=argv.length-2) 155 exit_with_help(); 156 try 157 { 158 BufferedReader input = new BufferedReader(new FileReader(argv[i])); 159 DataOutputStream output = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(argv[i+2]))); 160 svm_model model = svm.svm_load_model(argv[i+1]); 161 if (model == null) 162 { 163 System.err.print("can't open model file "+argv[i+1]+"\n"); 164 System.exit(1); 165 } 166 if(predict_probability == 1) 167 { 168 if(svm.svm_check_probability_model(model)==0) 169 { 170 System.err.print("Model does not support probabiliy estimates\n"); 171 System.exit(1); 172 } 173 } 174 else 175 { 176 if(svm.svm_check_probability_model(model)!=0) 177 { 178 svm_predict.info("Model supports probability estimates, but disabled in prediction.\n"); 179 } 180 } 181 predict(input,output,model,predict_probability); 182 input.close(); 183 output.close(); 184 } 185 catch(FileNotFoundException e) 186 { 187 exit_with_help(); 188 } 189 catch(ArrayIndexOutOfBoundsException e) 190 { 191 exit_with_help(); 192 } 193 } 194 } 195