1 import libsvm.*; 2 import java.io.*; 3 import java.util.*; 4 5 class svm_train { 6 private svm_parameter param; // set by parse_command_line 7 private svm_problem prob; // set by read_problem 8 private svm_model model; 9 private String input_file_name; // set by parse_command_line 10 private String model_file_name; // set by parse_command_line 11 private String error_msg; 12 private int cross_validation; 13 private int nr_fold; 14 15 private static svm_print_interface svm_print_null = new svm_print_interface() 16 { 17 public void print(String s) {} 18 }; 19 exit_with_help()20 private static void exit_with_help() 21 { 22 System.out.print( 23 "Usage: svm_train [options] training_set_file [model_file]\n" 24 +"options:\n" 25 +"-s svm_type : set type of SVM (default 0)\n" 26 +" 0 -- C-SVC (multi-class classification)\n" 27 +" 1 -- nu-SVC (multi-class classification)\n" 28 +" 2 -- one-class SVM\n" 29 +" 3 -- epsilon-SVR (regression)\n" 30 +" 4 -- nu-SVR (regression)\n" 31 +"-t kernel_type : set type of kernel function (default 2)\n" 32 +" 0 -- linear: u'*v\n" 33 +" 1 -- polynomial: (gamma*u'*v + coef0)^degree\n" 34 +" 2 -- radial basis function: exp(-gamma*|u-v|^2)\n" 35 +" 3 -- sigmoid: tanh(gamma*u'*v + coef0)\n" 36 +" 4 -- precomputed kernel (kernel values in training_set_file)\n" 37 +"-d degree : set degree in kernel function (default 3)\n" 38 +"-g gamma : set gamma in kernel function (default 1/num_features)\n" 39 +"-r coef0 : set coef0 in kernel function (default 0)\n" 40 +"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n" 41 +"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n" 42 +"-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n" 43 +"-m cachesize : set cache memory size in MB (default 100)\n" 44 +"-e epsilon : set tolerance of termination criterion (default 0.001)\n" 45 +"-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)\n" 46 +"-b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n" 47 +"-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)\n" 48 +"-v n : n-fold cross validation mode\n" 49 +"-q : quiet mode (no outputs)\n" 50 ); 51 System.exit(1); 52 } 53 do_cross_validation()54 private void do_cross_validation() 55 { 56 int i; 57 int total_correct = 0; 58 double total_error = 0; 59 double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0; 60 double[] target = new double[prob.l]; 61 62 svm.svm_cross_validation(prob,param,nr_fold,target); 63 if(param.svm_type == svm_parameter.EPSILON_SVR || 64 param.svm_type == svm_parameter.NU_SVR) 65 { 66 for(i=0;i<prob.l;i++) 67 { 68 double y = prob.y[i]; 69 double v = target[i]; 70 total_error += (v-y)*(v-y); 71 sumv += v; 72 sumy += y; 73 sumvv += v*v; 74 sumyy += y*y; 75 sumvy += v*y; 76 } 77 System.out.print("Cross Validation Mean squared error = "+total_error/prob.l+"\n"); 78 System.out.print("Cross Validation Squared correlation coefficient = "+ 79 ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/ 80 ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))+"\n" 81 ); 82 } 83 else 84 { 85 for(i=0;i<prob.l;i++) 86 if(target[i] == prob.y[i]) 87 ++total_correct; 88 System.out.print("Cross Validation Accuracy = "+100.0*total_correct/prob.l+"%\n"); 89 } 90 } 91 run(String argv[])92 private void run(String argv[]) throws IOException 93 { 94 parse_command_line(argv); 95 read_problem(); 96 error_msg = svm.svm_check_parameter(prob,param); 97 98 if(error_msg != null) 99 { 100 System.err.print("ERROR: "+error_msg+"\n"); 101 System.exit(1); 102 } 103 104 if(cross_validation != 0) 105 { 106 do_cross_validation(); 107 } 108 else 109 { 110 model = svm.svm_train(prob,param); 111 svm.svm_save_model(model_file_name,model); 112 } 113 } 114 main(String argv[])115 public static void main(String argv[]) throws IOException 116 { 117 svm_train t = new svm_train(); 118 t.run(argv); 119 } 120 atof(String s)121 private static double atof(String s) 122 { 123 double d = Double.valueOf(s).doubleValue(); 124 if (Double.isNaN(d) || Double.isInfinite(d)) 125 { 126 System.err.print("NaN or Infinity in input\n"); 127 System.exit(1); 128 } 129 return(d); 130 } 131 atoi(String s)132 private static int atoi(String s) 133 { 134 return Integer.parseInt(s); 135 } 136 parse_command_line(String argv[])137 private void parse_command_line(String argv[]) 138 { 139 int i; 140 svm_print_interface print_func = null; // default printing to stdout 141 142 param = new svm_parameter(); 143 // default values 144 param.svm_type = svm_parameter.C_SVC; 145 param.kernel_type = svm_parameter.RBF; 146 param.degree = 3; 147 param.gamma = 0; // 1/num_features 148 param.coef0 = 0; 149 param.nu = 0.5; 150 param.cache_size = 100; 151 param.C = 1; 152 param.eps = 1e-3; 153 param.p = 0.1; 154 param.shrinking = 1; 155 param.probability = 0; 156 param.nr_weight = 0; 157 param.weight_label = new int[0]; 158 param.weight = new double[0]; 159 cross_validation = 0; 160 161 // parse options 162 for(i=0;i<argv.length;i++) 163 { 164 if(argv[i].charAt(0) != '-') break; 165 if(++i>=argv.length) 166 exit_with_help(); 167 switch(argv[i-1].charAt(1)) 168 { 169 case 's': 170 param.svm_type = atoi(argv[i]); 171 break; 172 case 't': 173 param.kernel_type = atoi(argv[i]); 174 break; 175 case 'd': 176 param.degree = atoi(argv[i]); 177 break; 178 case 'g': 179 param.gamma = atof(argv[i]); 180 break; 181 case 'r': 182 param.coef0 = atof(argv[i]); 183 break; 184 case 'n': 185 param.nu = atof(argv[i]); 186 break; 187 case 'm': 188 param.cache_size = atof(argv[i]); 189 break; 190 case 'c': 191 param.C = atof(argv[i]); 192 break; 193 case 'e': 194 param.eps = atof(argv[i]); 195 break; 196 case 'p': 197 param.p = atof(argv[i]); 198 break; 199 case 'h': 200 param.shrinking = atoi(argv[i]); 201 break; 202 case 'b': 203 param.probability = atoi(argv[i]); 204 break; 205 case 'q': 206 print_func = svm_print_null; 207 i--; 208 break; 209 case 'v': 210 cross_validation = 1; 211 nr_fold = atoi(argv[i]); 212 if(nr_fold < 2) 213 { 214 System.err.print("n-fold cross validation: n must >= 2\n"); 215 exit_with_help(); 216 } 217 break; 218 case 'w': 219 ++param.nr_weight; 220 { 221 int[] old = param.weight_label; 222 param.weight_label = new int[param.nr_weight]; 223 System.arraycopy(old,0,param.weight_label,0,param.nr_weight-1); 224 } 225 226 { 227 double[] old = param.weight; 228 param.weight = new double[param.nr_weight]; 229 System.arraycopy(old,0,param.weight,0,param.nr_weight-1); 230 } 231 232 param.weight_label[param.nr_weight-1] = atoi(argv[i-1].substring(2)); 233 param.weight[param.nr_weight-1] = atof(argv[i]); 234 break; 235 default: 236 System.err.print("Unknown option: " + argv[i-1] + "\n"); 237 exit_with_help(); 238 } 239 } 240 241 svm.svm_set_print_string_function(print_func); 242 243 // determine filenames 244 245 if(i>=argv.length) 246 exit_with_help(); 247 248 input_file_name = argv[i]; 249 250 if(i<argv.length-1) 251 model_file_name = argv[i+1]; 252 else 253 { 254 int p = argv[i].lastIndexOf('/'); 255 ++p; // whew... 256 model_file_name = argv[i].substring(p)+".model"; 257 } 258 } 259 260 // read in a problem (in svmlight format) 261 read_problem()262 private void read_problem() throws IOException 263 { 264 BufferedReader fp = new BufferedReader(new FileReader(input_file_name)); 265 Vector<Double> vy = new Vector<Double>(); 266 Vector<svm_node[]> vx = new Vector<svm_node[]>(); 267 int max_index = 0; 268 269 while(true) 270 { 271 String line = fp.readLine(); 272 if(line == null) break; 273 274 StringTokenizer st = new StringTokenizer(line," \t\n\r\f:"); 275 276 vy.addElement(atof(st.nextToken())); 277 int m = st.countTokens()/2; 278 svm_node[] x = new svm_node[m]; 279 for(int j=0;j<m;j++) 280 { 281 x[j] = new svm_node(); 282 x[j].index = atoi(st.nextToken()); 283 x[j].value = atof(st.nextToken()); 284 } 285 if(m>0) max_index = Math.max(max_index, x[m-1].index); 286 vx.addElement(x); 287 } 288 289 prob = new svm_problem(); 290 prob.l = vy.size(); 291 prob.x = new svm_node[prob.l][]; 292 for(int i=0;i<prob.l;i++) 293 prob.x[i] = vx.elementAt(i); 294 prob.y = new double[prob.l]; 295 for(int i=0;i<prob.l;i++) 296 prob.y[i] = vy.elementAt(i); 297 298 if(param.gamma == 0 && max_index > 0) 299 param.gamma = 1.0/max_index; 300 301 if(param.kernel_type == svm_parameter.PRECOMPUTED) 302 for(int i=0;i<prob.l;i++) 303 { 304 if (prob.x[i][0].index != 0) 305 { 306 System.err.print("Wrong kernel matrix: first column must be 0:sample_serial_number\n"); 307 System.exit(1); 308 } 309 if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index) 310 { 311 System.err.print("Wrong input format: sample_serial_number out of range\n"); 312 System.exit(1); 313 } 314 } 315 316 fp.close(); 317 } 318 } 319