1 2 3 4 5 6 package libsvm; 7 import java.io.*; 8 import java.util.*; 9 10 // 11 // Kernel Cache 12 // 13 // l is the number of total data items 14 // size is the cache size limit in bytes 15 // 16 class Cache { 17 private final int l; 18 private long size; 19 private final class head_t 20 { 21 head_t prev, next; // a cicular list 22 float[] data; 23 int len; // data[0,len) is cached in this entry 24 } 25 private final head_t[] head; 26 private head_t lru_head; 27 Cache(int l_, long size_)28 Cache(int l_, long size_) 29 { 30 l = l_; 31 size = size_; 32 head = new head_t[l]; 33 for(int i=0;i<l;i++) head[i] = new head_t(); 34 size /= 4; 35 size -= l * (16/4); // sizeof(head_t) == 16 36 size = Math.max(size, 2* (long) l); // cache must be large enough for two columns 37 lru_head = new head_t(); 38 lru_head.next = lru_head.prev = lru_head; 39 } 40 lru_delete(head_t h)41 private void lru_delete(head_t h) 42 { 43 // delete from current location 44 h.prev.next = h.next; 45 h.next.prev = h.prev; 46 } 47 lru_insert(head_t h)48 private void lru_insert(head_t h) 49 { 50 // insert to last position 51 h.next = lru_head; 52 h.prev = lru_head.prev; 53 h.prev.next = h; 54 h.next.prev = h; 55 } 56 57 // request data [0,len) 58 // return some position p where [p,len) need to be filled 59 // (p >= len if nothing needs to be filled) 60 // java: simulate pointer using single-element array get_data(int index, float[][] data, int len)61 int get_data(int index, float[][] data, int len) 62 { 63 head_t h = head[index]; 64 if(h.len > 0) lru_delete(h); 65 int more = len - h.len; 66 67 if(more > 0) 68 { 69 // free old space 70 while(size < more) 71 { 72 head_t old = lru_head.next; 73 lru_delete(old); 74 size += old.len; 75 old.data = null; 76 old.len = 0; 77 } 78 79 // allocate new space 80 float[] new_data = new float[len]; 81 if(h.data != null) System.arraycopy(h.data,0,new_data,0,h.len); 82 h.data = new_data; 83 size -= more; 84 do {int tmp=h.len; h.len=len; len=tmp;} while(false); 85 } 86 87 lru_insert(h); 88 data[0] = h.data; 89 return len; 90 } 91 swap_index(int i, int j)92 void swap_index(int i, int j) 93 { 94 if(i==j) return; 95 96 if(head[i].len > 0) lru_delete(head[i]); 97 if(head[j].len > 0) lru_delete(head[j]); 98 do {float[] tmp=head[i].data; head[i].data=head[j].data; head[j].data=tmp;} while(false); 99 do {int tmp=head[i].len; head[i].len=head[j].len; head[j].len=tmp;} while(false); 100 if(head[i].len > 0) lru_insert(head[i]); 101 if(head[j].len > 0) lru_insert(head[j]); 102 103 if(i>j) do {int tmp=i; i=j; j=tmp;} while(false); 104 for(head_t h = lru_head.next; h!=lru_head; h=h.next) 105 { 106 if(h.len > i) 107 { 108 if(h.len > j) 109 do {float tmp=h.data[i]; h.data[i]=h.data[j]; h.data[j]=tmp;} while(false); 110 else 111 { 112 // give up 113 lru_delete(h); 114 size += h.len; 115 h.data = null; 116 h.len = 0; 117 } 118 } 119 } 120 } 121 } 122 123 // 124 // Kernel evaluation 125 // 126 // the static method k_function is for doing single kernel evaluation 127 // the constructor of Kernel prepares to calculate the l*l kernel matrix 128 // the member function get_Q is for getting one column from the Q Matrix 129 // 130 abstract class QMatrix { get_Q(int column, int len)131 abstract float[] get_Q(int column, int len); get_QD()132 abstract double[] get_QD(); swap_index(int i, int j)133 abstract void swap_index(int i, int j); 134 }; 135 136 abstract class Kernel extends QMatrix { 137 private svm_node[][] x; 138 private final double[] x_square; 139 140 // svm_parameter 141 private final int kernel_type; 142 private final int degree; 143 private final double gamma; 144 private final double coef0; 145 get_Q(int column, int len)146 abstract float[] get_Q(int column, int len); get_QD()147 abstract double[] get_QD(); 148 swap_index(int i, int j)149 void swap_index(int i, int j) 150 { 151 do {svm_node[] tmp=x[i]; x[i]=x[j]; x[j]=tmp;} while(false); 152 if(x_square != null) do {double tmp=x_square[i]; x_square[i]=x_square[j]; x_square[j]=tmp;} while(false); 153 } 154 powi(double base, int times)155 private static double powi(double base, int times) 156 { 157 double tmp = base, ret = 1.0; 158 159 for(int t=times; t>0; t/=2) 160 { 161 if(t%2==1) ret*=tmp; 162 tmp = tmp * tmp; 163 } 164 return ret; 165 } 166 kernel_function(int i, int j)167 double kernel_function(int i, int j) 168 { 169 switch(kernel_type) 170 { 171 case svm_parameter.LINEAR: 172 return dot(x[i],x[j]); 173 case svm_parameter.POLY: 174 return powi(gamma*dot(x[i],x[j])+coef0,degree); 175 case svm_parameter.RBF: 176 return Math.exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j]))); 177 case svm_parameter.SIGMOID: 178 return Math.tanh(gamma*dot(x[i],x[j])+coef0); 179 case svm_parameter.PRECOMPUTED: 180 return x[i][(int)(x[j][0].value)].value; 181 default: 182 return 0; // Unreachable 183 } 184 } 185 Kernel(int l, svm_node[][] x_, svm_parameter param)186 Kernel(int l, svm_node[][] x_, svm_parameter param) 187 { 188 this.kernel_type = param.kernel_type; 189 this.degree = param.degree; 190 this.gamma = param.gamma; 191 this.coef0 = param.coef0; 192 193 x = (svm_node[][])x_.clone(); 194 195 if(kernel_type == svm_parameter.RBF) 196 { 197 x_square = new double[l]; 198 for(int i=0;i<l;i++) 199 x_square[i] = dot(x[i],x[i]); 200 } 201 else 202 x_square = null; 203 } 204 dot(svm_node[] x, svm_node[] y)205 static double dot(svm_node[] x, svm_node[] y) 206 { 207 double sum = 0; 208 int xlen = x.length; 209 int ylen = y.length; 210 int i = 0; 211 int j = 0; 212 while(i < xlen && j < ylen) 213 { 214 if(x[i].index == y[j].index) 215 sum += x[i++].value * y[j++].value; 216 else 217 { 218 if(x[i].index > y[j].index) 219 ++j; 220 else 221 ++i; 222 } 223 } 224 return sum; 225 } 226 k_function(svm_node[] x, svm_node[] y, svm_parameter param)227 static double k_function(svm_node[] x, svm_node[] y, 228 svm_parameter param) 229 { 230 switch(param.kernel_type) 231 { 232 case svm_parameter.LINEAR: 233 return dot(x,y); 234 case svm_parameter.POLY: 235 return powi(param.gamma*dot(x,y)+param.coef0,param.degree); 236 case svm_parameter.RBF: 237 { 238 double sum = 0; 239 int xlen = x.length; 240 int ylen = y.length; 241 int i = 0; 242 int j = 0; 243 while(i < xlen && j < ylen) 244 { 245 if(x[i].index == y[j].index) 246 { 247 double d = x[i++].value - y[j++].value; 248 sum += d*d; 249 } 250 else if(x[i].index > y[j].index) 251 { 252 sum += y[j].value * y[j].value; 253 ++j; 254 } 255 else 256 { 257 sum += x[i].value * x[i].value; 258 ++i; 259 } 260 } 261 262 while(i < xlen) 263 { 264 sum += x[i].value * x[i].value; 265 ++i; 266 } 267 268 while(j < ylen) 269 { 270 sum += y[j].value * y[j].value; 271 ++j; 272 } 273 274 return Math.exp(-param.gamma*sum); 275 } 276 case svm_parameter.SIGMOID: 277 return Math.tanh(param.gamma*dot(x,y)+param.coef0); 278 case svm_parameter.PRECOMPUTED: //x: test (validation), y: SV 279 return x[(int)(y[0].value)].value; 280 default: 281 return 0; // Unreachable 282 } 283 } 284 } 285 286 // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918 287 // Solves: 288 // 289 // min 0.5(\alpha^T Q \alpha) + p^T \alpha 290 // 291 // y^T \alpha = \delta 292 // y_i = +1 or -1 293 // 0 <= alpha_i <= Cp for y_i = 1 294 // 0 <= alpha_i <= Cn for y_i = -1 295 // 296 // Given: 297 // 298 // Q, p, y, Cp, Cn, and an initial feasible point \alpha 299 // l is the size of vectors and matrices 300 // eps is the stopping tolerance 301 // 302 // solution will be put in \alpha, objective value will be put in obj 303 // 304 class Solver { 305 int active_size; 306 byte[] y; 307 double[] G; // gradient of objective function 308 static final byte LOWER_BOUND = 0; 309 static final byte UPPER_BOUND = 1; 310 static final byte FREE = 2; 311 byte[] alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE 312 double[] alpha; 313 QMatrix Q; 314 double[] QD; 315 double eps; 316 double Cp,Cn; 317 double[] p; 318 int[] active_set; 319 double[] G_bar; // gradient, if we treat free variables as 0 320 int l; 321 boolean unshrink; // XXX 322 323 static final double INF = java.lang.Double.POSITIVE_INFINITY; 324 get_C(int i)325 double get_C(int i) 326 { 327 return (y[i] > 0)? Cp : Cn; 328 } update_alpha_status(int i)329 void update_alpha_status(int i) 330 { 331 if(alpha[i] >= get_C(i)) 332 alpha_status[i] = UPPER_BOUND; 333 else if(alpha[i] <= 0) 334 alpha_status[i] = LOWER_BOUND; 335 else alpha_status[i] = FREE; 336 } is_upper_bound(int i)337 boolean is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; } is_lower_bound(int i)338 boolean is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; } is_free(int i)339 boolean is_free(int i) { return alpha_status[i] == FREE; } 340 341 // java: information about solution except alpha, 342 // because we cannot return multiple values otherwise... 343 static class SolutionInfo { 344 double obj; 345 double rho; 346 double upper_bound_p; 347 double upper_bound_n; 348 double r; // for Solver_NU 349 } 350 swap_index(int i, int j)351 void swap_index(int i, int j) 352 { 353 Q.swap_index(i,j); 354 do {byte tmp=y[i]; y[i]=y[j]; y[j]=tmp;} while(false); 355 do {double tmp=G[i]; G[i]=G[j]; G[j]=tmp;} while(false); 356 do {byte tmp=alpha_status[i]; alpha_status[i]=alpha_status[j]; alpha_status[j]=tmp;} while(false); 357 do {double tmp=alpha[i]; alpha[i]=alpha[j]; alpha[j]=tmp;} while(false); 358 do {double tmp=p[i]; p[i]=p[j]; p[j]=tmp;} while(false); 359 do {int tmp=active_set[i]; active_set[i]=active_set[j]; active_set[j]=tmp;} while(false); 360 do {double tmp=G_bar[i]; G_bar[i]=G_bar[j]; G_bar[j]=tmp;} while(false); 361 } 362 reconstruct_gradient()363 void reconstruct_gradient() 364 { 365 // reconstruct inactive elements of G from G_bar and free variables 366 367 if(active_size == l) return; 368 369 int i,j; 370 int nr_free = 0; 371 372 for(j=active_size;j<l;j++) 373 G[j] = G_bar[j] + p[j]; 374 375 for(j=0;j<active_size;j++) 376 if(is_free(j)) 377 nr_free++; 378 379 if(2*nr_free < active_size) 380 svm.info("\nWARNING: using -h 0 may be faster\n"); 381 382 if (nr_free*l > 2*active_size*(l-active_size)) 383 { 384 for(i=active_size;i<l;i++) 385 { 386 float[] Q_i = Q.get_Q(i,active_size); 387 for(j=0;j<active_size;j++) 388 if(is_free(j)) 389 G[i] += alpha[j] * Q_i[j]; 390 } 391 } 392 else 393 { 394 for(i=0;i<active_size;i++) 395 if(is_free(i)) 396 { 397 float[] Q_i = Q.get_Q(i,l); 398 double alpha_i = alpha[i]; 399 for(j=active_size;j<l;j++) 400 G[j] += alpha_i * Q_i[j]; 401 } 402 } 403 } 404 Solve(int l, QMatrix Q, double[] p_, byte[] y_, double[] alpha_, double Cp, double Cn, double eps, SolutionInfo si, int shrinking)405 void Solve(int l, QMatrix Q, double[] p_, byte[] y_, 406 double[] alpha_, double Cp, double Cn, double eps, SolutionInfo si, int shrinking) 407 { 408 this.l = l; 409 this.Q = Q; 410 QD = Q.get_QD(); 411 p = (double[])p_.clone(); 412 y = (byte[])y_.clone(); 413 alpha = (double[])alpha_.clone(); 414 this.Cp = Cp; 415 this.Cn = Cn; 416 this.eps = eps; 417 this.unshrink = false; 418 419 // initialize alpha_status 420 { 421 alpha_status = new byte[l]; 422 for(int i=0;i<l;i++) 423 update_alpha_status(i); 424 } 425 426 // initialize active set (for shrinking) 427 { 428 active_set = new int[l]; 429 for(int i=0;i<l;i++) 430 active_set[i] = i; 431 active_size = l; 432 } 433 434 // initialize gradient 435 { 436 G = new double[l]; 437 G_bar = new double[l]; 438 int i; 439 for(i=0;i<l;i++) 440 { 441 G[i] = p[i]; 442 G_bar[i] = 0; 443 } 444 for(i=0;i<l;i++) 445 if(!is_lower_bound(i)) 446 { 447 float[] Q_i = Q.get_Q(i,l); 448 double alpha_i = alpha[i]; 449 int j; 450 for(j=0;j<l;j++) 451 G[j] += alpha_i*Q_i[j]; 452 if(is_upper_bound(i)) 453 for(j=0;j<l;j++) 454 G_bar[j] += get_C(i) * Q_i[j]; 455 } 456 } 457 458 // optimization step 459 460 int iter = 0; 461 int max_iter = Math.max(10000000, l>Integer.MAX_VALUE/100 ? Integer.MAX_VALUE : 100*l); 462 int counter = Math.min(l,1000)+1; 463 int[] working_set = new int[2]; 464 465 while(iter < max_iter) 466 { 467 // show progress and do shrinking 468 469 if(--counter == 0) 470 { 471 counter = Math.min(l,1000); 472 if(shrinking!=0) do_shrinking(); 473 svm.info("."); 474 } 475 476 if(select_working_set(working_set)!=0) 477 { 478 // reconstruct the whole gradient 479 reconstruct_gradient(); 480 // reset active set size and check 481 active_size = l; 482 svm.info("*"); 483 if(select_working_set(working_set)!=0) 484 break; 485 else 486 counter = 1; // do shrinking next iteration 487 } 488 489 int i = working_set[0]; 490 int j = working_set[1]; 491 492 ++iter; 493 494 // update alpha[i] and alpha[j], handle bounds carefully 495 496 float[] Q_i = Q.get_Q(i,active_size); 497 float[] Q_j = Q.get_Q(j,active_size); 498 499 double C_i = get_C(i); 500 double C_j = get_C(j); 501 502 double old_alpha_i = alpha[i]; 503 double old_alpha_j = alpha[j]; 504 505 if(y[i]!=y[j]) 506 { 507 double quad_coef = QD[i]+QD[j]+2*Q_i[j]; 508 if (quad_coef <= 0) 509 quad_coef = 1e-12; 510 double delta = (-G[i]-G[j])/quad_coef; 511 double diff = alpha[i] - alpha[j]; 512 alpha[i] += delta; 513 alpha[j] += delta; 514 515 if(diff > 0) 516 { 517 if(alpha[j] < 0) 518 { 519 alpha[j] = 0; 520 alpha[i] = diff; 521 } 522 } 523 else 524 { 525 if(alpha[i] < 0) 526 { 527 alpha[i] = 0; 528 alpha[j] = -diff; 529 } 530 } 531 if(diff > C_i - C_j) 532 { 533 if(alpha[i] > C_i) 534 { 535 alpha[i] = C_i; 536 alpha[j] = C_i - diff; 537 } 538 } 539 else 540 { 541 if(alpha[j] > C_j) 542 { 543 alpha[j] = C_j; 544 alpha[i] = C_j + diff; 545 } 546 } 547 } 548 else 549 { 550 double quad_coef = QD[i]+QD[j]-2*Q_i[j]; 551 if (quad_coef <= 0) 552 quad_coef = 1e-12; 553 double delta = (G[i]-G[j])/quad_coef; 554 double sum = alpha[i] + alpha[j]; 555 alpha[i] -= delta; 556 alpha[j] += delta; 557 558 if(sum > C_i) 559 { 560 if(alpha[i] > C_i) 561 { 562 alpha[i] = C_i; 563 alpha[j] = sum - C_i; 564 } 565 } 566 else 567 { 568 if(alpha[j] < 0) 569 { 570 alpha[j] = 0; 571 alpha[i] = sum; 572 } 573 } 574 if(sum > C_j) 575 { 576 if(alpha[j] > C_j) 577 { 578 alpha[j] = C_j; 579 alpha[i] = sum - C_j; 580 } 581 } 582 else 583 { 584 if(alpha[i] < 0) 585 { 586 alpha[i] = 0; 587 alpha[j] = sum; 588 } 589 } 590 } 591 592 // update G 593 594 double delta_alpha_i = alpha[i] - old_alpha_i; 595 double delta_alpha_j = alpha[j] - old_alpha_j; 596 597 for(int k=0;k<active_size;k++) 598 { 599 G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j; 600 } 601 602 // update alpha_status and G_bar 603 604 { 605 boolean ui = is_upper_bound(i); 606 boolean uj = is_upper_bound(j); 607 update_alpha_status(i); 608 update_alpha_status(j); 609 int k; 610 if(ui != is_upper_bound(i)) 611 { 612 Q_i = Q.get_Q(i,l); 613 if(ui) 614 for(k=0;k<l;k++) 615 G_bar[k] -= C_i * Q_i[k]; 616 else 617 for(k=0;k<l;k++) 618 G_bar[k] += C_i * Q_i[k]; 619 } 620 621 if(uj != is_upper_bound(j)) 622 { 623 Q_j = Q.get_Q(j,l); 624 if(uj) 625 for(k=0;k<l;k++) 626 G_bar[k] -= C_j * Q_j[k]; 627 else 628 for(k=0;k<l;k++) 629 G_bar[k] += C_j * Q_j[k]; 630 } 631 } 632 633 } 634 635 if(iter >= max_iter) 636 { 637 if(active_size < l) 638 { 639 // reconstruct the whole gradient to calculate objective value 640 reconstruct_gradient(); 641 active_size = l; 642 svm.info("*"); 643 } 644 System.err.print("\nWARNING: reaching max number of iterations\n"); 645 } 646 647 // calculate rho 648 649 si.rho = calculate_rho(); 650 651 // calculate objective value 652 { 653 double v = 0; 654 int i; 655 for(i=0;i<l;i++) 656 v += alpha[i] * (G[i] + p[i]); 657 658 si.obj = v/2; 659 } 660 661 // put back the solution 662 { 663 for(int i=0;i<l;i++) 664 alpha_[active_set[i]] = alpha[i]; 665 } 666 667 si.upper_bound_p = Cp; 668 si.upper_bound_n = Cn; 669 670 svm.info("\noptimization finished, #iter = "+iter+"\n"); 671 } 672 673 // return 1 if already optimal, return 0 otherwise select_working_set(int[] working_set)674 int select_working_set(int[] working_set) 675 { 676 // return i,j such that 677 // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha) 678 // j: minimizes the decrease of obj value 679 // (if quadratic coefficeint <= 0, replace it with tau) 680 // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha) 681 682 double Gmax = -INF; 683 double Gmax2 = -INF; 684 int Gmax_idx = -1; 685 int Gmin_idx = -1; 686 double obj_diff_min = INF; 687 688 for(int t=0;t<active_size;t++) 689 if(y[t]==+1) 690 { 691 if(!is_upper_bound(t)) 692 if(-G[t] >= Gmax) 693 { 694 Gmax = -G[t]; 695 Gmax_idx = t; 696 } 697 } 698 else 699 { 700 if(!is_lower_bound(t)) 701 if(G[t] >= Gmax) 702 { 703 Gmax = G[t]; 704 Gmax_idx = t; 705 } 706 } 707 708 int i = Gmax_idx; 709 float[] Q_i = null; 710 if(i != -1) // null Q_i not accessed: Gmax=-INF if i=-1 711 Q_i = Q.get_Q(i,active_size); 712 713 for(int j=0;j<active_size;j++) 714 { 715 if(y[j]==+1) 716 { 717 if (!is_lower_bound(j)) 718 { 719 double grad_diff=Gmax+G[j]; 720 if (G[j] >= Gmax2) 721 Gmax2 = G[j]; 722 if (grad_diff > 0) 723 { 724 double obj_diff; 725 double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j]; 726 if (quad_coef > 0) 727 obj_diff = -(grad_diff*grad_diff)/quad_coef; 728 else 729 obj_diff = -(grad_diff*grad_diff)/1e-12; 730 731 if (obj_diff <= obj_diff_min) 732 { 733 Gmin_idx=j; 734 obj_diff_min = obj_diff; 735 } 736 } 737 } 738 } 739 else 740 { 741 if (!is_upper_bound(j)) 742 { 743 double grad_diff= Gmax-G[j]; 744 if (-G[j] >= Gmax2) 745 Gmax2 = -G[j]; 746 if (grad_diff > 0) 747 { 748 double obj_diff; 749 double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j]; 750 if (quad_coef > 0) 751 obj_diff = -(grad_diff*grad_diff)/quad_coef; 752 else 753 obj_diff = -(grad_diff*grad_diff)/1e-12; 754 755 if (obj_diff <= obj_diff_min) 756 { 757 Gmin_idx=j; 758 obj_diff_min = obj_diff; 759 } 760 } 761 } 762 } 763 } 764 765 if(Gmax+Gmax2 < eps || Gmin_idx == -1) 766 return 1; 767 768 working_set[0] = Gmax_idx; 769 working_set[1] = Gmin_idx; 770 return 0; 771 } 772 be_shrunk(int i, double Gmax1, double Gmax2)773 private boolean be_shrunk(int i, double Gmax1, double Gmax2) 774 { 775 if(is_upper_bound(i)) 776 { 777 if(y[i]==+1) 778 return(-G[i] > Gmax1); 779 else 780 return(-G[i] > Gmax2); 781 } 782 else if(is_lower_bound(i)) 783 { 784 if(y[i]==+1) 785 return(G[i] > Gmax2); 786 else 787 return(G[i] > Gmax1); 788 } 789 else 790 return(false); 791 } 792 do_shrinking()793 void do_shrinking() 794 { 795 int i; 796 double Gmax1 = -INF; // max { -y_i * grad(f)_i | i in I_up(\alpha) } 797 double Gmax2 = -INF; // max { y_i * grad(f)_i | i in I_low(\alpha) } 798 799 // find maximal violating pair first 800 for(i=0;i<active_size;i++) 801 { 802 if(y[i]==+1) 803 { 804 if(!is_upper_bound(i)) 805 { 806 if(-G[i] >= Gmax1) 807 Gmax1 = -G[i]; 808 } 809 if(!is_lower_bound(i)) 810 { 811 if(G[i] >= Gmax2) 812 Gmax2 = G[i]; 813 } 814 } 815 else 816 { 817 if(!is_upper_bound(i)) 818 { 819 if(-G[i] >= Gmax2) 820 Gmax2 = -G[i]; 821 } 822 if(!is_lower_bound(i)) 823 { 824 if(G[i] >= Gmax1) 825 Gmax1 = G[i]; 826 } 827 } 828 } 829 830 if(unshrink == false && Gmax1 + Gmax2 <= eps*10) 831 { 832 unshrink = true; 833 reconstruct_gradient(); 834 active_size = l; 835 svm.info("*"); 836 } 837 838 for(i=0;i<active_size;i++) 839 if (be_shrunk(i, Gmax1, Gmax2)) 840 { 841 active_size--; 842 while (active_size > i) 843 { 844 if (!be_shrunk(active_size, Gmax1, Gmax2)) 845 { 846 swap_index(i,active_size); 847 break; 848 } 849 active_size--; 850 } 851 } 852 } 853 calculate_rho()854 double calculate_rho() 855 { 856 double r; 857 int nr_free = 0; 858 double ub = INF, lb = -INF, sum_free = 0; 859 for(int i=0;i<active_size;i++) 860 { 861 double yG = y[i]*G[i]; 862 863 if(is_upper_bound(i)) 864 { 865 if(y[i] < 0) 866 ub = Math.min(ub,yG); 867 else 868 lb = Math.max(lb,yG); 869 } 870 else if(is_lower_bound(i)) 871 { 872 if(y[i] > 0) 873 ub = Math.min(ub,yG); 874 else 875 lb = Math.max(lb,yG); 876 } 877 else 878 { 879 ++nr_free; 880 sum_free += yG; 881 } 882 } 883 884 if(nr_free>0) 885 r = sum_free/nr_free; 886 else 887 r = (ub+lb)/2; 888 889 return r; 890 } 891 892 } 893 894 // 895 // Solver for nu-svm classification and regression 896 // 897 // additional constraint: e^T \alpha = constant 898 // 899 final class Solver_NU extends Solver 900 { 901 private SolutionInfo si; 902 Solve(int l, QMatrix Q, double[] p, byte[] y, double[] alpha, double Cp, double Cn, double eps, SolutionInfo si, int shrinking)903 void Solve(int l, QMatrix Q, double[] p, byte[] y, 904 double[] alpha, double Cp, double Cn, double eps, 905 SolutionInfo si, int shrinking) 906 { 907 this.si = si; 908 super.Solve(l,Q,p,y,alpha,Cp,Cn,eps,si,shrinking); 909 } 910 911 // return 1 if already optimal, return 0 otherwise select_working_set(int[] working_set)912 int select_working_set(int[] working_set) 913 { 914 // return i,j such that y_i = y_j and 915 // i: maximizes -y_i * grad(f)_i, i in I_up(\alpha) 916 // j: minimizes the decrease of obj value 917 // (if quadratic coefficeint <= 0, replace it with tau) 918 // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha) 919 920 double Gmaxp = -INF; 921 double Gmaxp2 = -INF; 922 int Gmaxp_idx = -1; 923 924 double Gmaxn = -INF; 925 double Gmaxn2 = -INF; 926 int Gmaxn_idx = -1; 927 928 int Gmin_idx = -1; 929 double obj_diff_min = INF; 930 931 for(int t=0;t<active_size;t++) 932 if(y[t]==+1) 933 { 934 if(!is_upper_bound(t)) 935 if(-G[t] >= Gmaxp) 936 { 937 Gmaxp = -G[t]; 938 Gmaxp_idx = t; 939 } 940 } 941 else 942 { 943 if(!is_lower_bound(t)) 944 if(G[t] >= Gmaxn) 945 { 946 Gmaxn = G[t]; 947 Gmaxn_idx = t; 948 } 949 } 950 951 int ip = Gmaxp_idx; 952 int in = Gmaxn_idx; 953 float[] Q_ip = null; 954 float[] Q_in = null; 955 if(ip != -1) // null Q_ip not accessed: Gmaxp=-INF if ip=-1 956 Q_ip = Q.get_Q(ip,active_size); 957 if(in != -1) 958 Q_in = Q.get_Q(in,active_size); 959 960 for(int j=0;j<active_size;j++) 961 { 962 if(y[j]==+1) 963 { 964 if (!is_lower_bound(j)) 965 { 966 double grad_diff=Gmaxp+G[j]; 967 if (G[j] >= Gmaxp2) 968 Gmaxp2 = G[j]; 969 if (grad_diff > 0) 970 { 971 double obj_diff; 972 double quad_coef = QD[ip]+QD[j]-2*Q_ip[j]; 973 if (quad_coef > 0) 974 obj_diff = -(grad_diff*grad_diff)/quad_coef; 975 else 976 obj_diff = -(grad_diff*grad_diff)/1e-12; 977 978 if (obj_diff <= obj_diff_min) 979 { 980 Gmin_idx=j; 981 obj_diff_min = obj_diff; 982 } 983 } 984 } 985 } 986 else 987 { 988 if (!is_upper_bound(j)) 989 { 990 double grad_diff=Gmaxn-G[j]; 991 if (-G[j] >= Gmaxn2) 992 Gmaxn2 = -G[j]; 993 if (grad_diff > 0) 994 { 995 double obj_diff; 996 double quad_coef = QD[in]+QD[j]-2*Q_in[j]; 997 if (quad_coef > 0) 998 obj_diff = -(grad_diff*grad_diff)/quad_coef; 999 else 1000 obj_diff = -(grad_diff*grad_diff)/1e-12; 1001 1002 if (obj_diff <= obj_diff_min) 1003 { 1004 Gmin_idx=j; 1005 obj_diff_min = obj_diff; 1006 } 1007 } 1008 } 1009 } 1010 } 1011 1012 if(Math.max(Gmaxp+Gmaxp2,Gmaxn+Gmaxn2) < eps || Gmin_idx == -1) 1013 return 1; 1014 1015 if(y[Gmin_idx] == +1) 1016 working_set[0] = Gmaxp_idx; 1017 else 1018 working_set[0] = Gmaxn_idx; 1019 working_set[1] = Gmin_idx; 1020 1021 return 0; 1022 } 1023 be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4)1024 private boolean be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, double Gmax4) 1025 { 1026 if(is_upper_bound(i)) 1027 { 1028 if(y[i]==+1) 1029 return(-G[i] > Gmax1); 1030 else 1031 return(-G[i] > Gmax4); 1032 } 1033 else if(is_lower_bound(i)) 1034 { 1035 if(y[i]==+1) 1036 return(G[i] > Gmax2); 1037 else 1038 return(G[i] > Gmax3); 1039 } 1040 else 1041 return(false); 1042 } 1043 do_shrinking()1044 void do_shrinking() 1045 { 1046 double Gmax1 = -INF; // max { -y_i * grad(f)_i | y_i = +1, i in I_up(\alpha) } 1047 double Gmax2 = -INF; // max { y_i * grad(f)_i | y_i = +1, i in I_low(\alpha) } 1048 double Gmax3 = -INF; // max { -y_i * grad(f)_i | y_i = -1, i in I_up(\alpha) } 1049 double Gmax4 = -INF; // max { y_i * grad(f)_i | y_i = -1, i in I_low(\alpha) } 1050 1051 // find maximal violating pair first 1052 int i; 1053 for(i=0;i<active_size;i++) 1054 { 1055 if(!is_upper_bound(i)) 1056 { 1057 if(y[i]==+1) 1058 { 1059 if(-G[i] > Gmax1) Gmax1 = -G[i]; 1060 } 1061 else if(-G[i] > Gmax4) Gmax4 = -G[i]; 1062 } 1063 if(!is_lower_bound(i)) 1064 { 1065 if(y[i]==+1) 1066 { 1067 if(G[i] > Gmax2) Gmax2 = G[i]; 1068 } 1069 else if(G[i] > Gmax3) Gmax3 = G[i]; 1070 } 1071 } 1072 1073 if(unshrink == false && Math.max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) 1074 { 1075 unshrink = true; 1076 reconstruct_gradient(); 1077 active_size = l; 1078 } 1079 1080 for(i=0;i<active_size;i++) 1081 if (be_shrunk(i, Gmax1, Gmax2, Gmax3, Gmax4)) 1082 { 1083 active_size--; 1084 while (active_size > i) 1085 { 1086 if (!be_shrunk(active_size, Gmax1, Gmax2, Gmax3, Gmax4)) 1087 { 1088 swap_index(i,active_size); 1089 break; 1090 } 1091 active_size--; 1092 } 1093 } 1094 } 1095 calculate_rho()1096 double calculate_rho() 1097 { 1098 int nr_free1 = 0,nr_free2 = 0; 1099 double ub1 = INF, ub2 = INF; 1100 double lb1 = -INF, lb2 = -INF; 1101 double sum_free1 = 0, sum_free2 = 0; 1102 1103 for(int i=0;i<active_size;i++) 1104 { 1105 if(y[i]==+1) 1106 { 1107 if(is_upper_bound(i)) 1108 lb1 = Math.max(lb1,G[i]); 1109 else if(is_lower_bound(i)) 1110 ub1 = Math.min(ub1,G[i]); 1111 else 1112 { 1113 ++nr_free1; 1114 sum_free1 += G[i]; 1115 } 1116 } 1117 else 1118 { 1119 if(is_upper_bound(i)) 1120 lb2 = Math.max(lb2,G[i]); 1121 else if(is_lower_bound(i)) 1122 ub2 = Math.min(ub2,G[i]); 1123 else 1124 { 1125 ++nr_free2; 1126 sum_free2 += G[i]; 1127 } 1128 } 1129 } 1130 1131 double r1,r2; 1132 if(nr_free1 > 0) 1133 r1 = sum_free1/nr_free1; 1134 else 1135 r1 = (ub1+lb1)/2; 1136 1137 if(nr_free2 > 0) 1138 r2 = sum_free2/nr_free2; 1139 else 1140 r2 = (ub2+lb2)/2; 1141 1142 si.r = (r1+r2)/2; 1143 return (r1-r2)/2; 1144 } 1145 } 1146 1147 // 1148 // Q matrices for various formulations 1149 // 1150 class SVC_Q extends Kernel 1151 { 1152 private final byte[] y; 1153 private final Cache cache; 1154 private final double[] QD; 1155 SVC_Q(svm_problem prob, svm_parameter param, byte[] y_)1156 SVC_Q(svm_problem prob, svm_parameter param, byte[] y_) 1157 { 1158 super(prob.l, prob.x, param); 1159 y = (byte[])y_.clone(); 1160 cache = new Cache(prob.l,(long)(param.cache_size*(1<<20))); 1161 QD = new double[prob.l]; 1162 for(int i=0;i<prob.l;i++) 1163 QD[i] = kernel_function(i,i); 1164 } 1165 get_Q(int i, int len)1166 float[] get_Q(int i, int len) 1167 { 1168 float[][] data = new float[1][]; 1169 int start, j; 1170 if((start = cache.get_data(i,data,len)) < len) 1171 { 1172 for(j=start;j<len;j++) 1173 data[0][j] = (float)(y[i]*y[j]*kernel_function(i,j)); 1174 } 1175 return data[0]; 1176 } 1177 get_QD()1178 double[] get_QD() 1179 { 1180 return QD; 1181 } 1182 swap_index(int i, int j)1183 void swap_index(int i, int j) 1184 { 1185 cache.swap_index(i,j); 1186 super.swap_index(i,j); 1187 do {byte tmp=y[i]; y[i]=y[j]; y[j]=tmp;} while(false); 1188 do {double tmp=QD[i]; QD[i]=QD[j]; QD[j]=tmp;} while(false); 1189 } 1190 } 1191 1192 class ONE_CLASS_Q extends Kernel 1193 { 1194 private final Cache cache; 1195 private final double[] QD; 1196 ONE_CLASS_Q(svm_problem prob, svm_parameter param)1197 ONE_CLASS_Q(svm_problem prob, svm_parameter param) 1198 { 1199 super(prob.l, prob.x, param); 1200 cache = new Cache(prob.l,(long)(param.cache_size*(1<<20))); 1201 QD = new double[prob.l]; 1202 for(int i=0;i<prob.l;i++) 1203 QD[i] = kernel_function(i,i); 1204 } 1205 get_Q(int i, int len)1206 float[] get_Q(int i, int len) 1207 { 1208 float[][] data = new float[1][]; 1209 int start, j; 1210 if((start = cache.get_data(i,data,len)) < len) 1211 { 1212 for(j=start;j<len;j++) 1213 data[0][j] = (float)kernel_function(i,j); 1214 } 1215 return data[0]; 1216 } 1217 get_QD()1218 double[] get_QD() 1219 { 1220 return QD; 1221 } 1222 swap_index(int i, int j)1223 void swap_index(int i, int j) 1224 { 1225 cache.swap_index(i,j); 1226 super.swap_index(i,j); 1227 do {double tmp=QD[i]; QD[i]=QD[j]; QD[j]=tmp;} while(false); 1228 } 1229 } 1230 1231 class SVR_Q extends Kernel 1232 { 1233 private final int l; 1234 private final Cache cache; 1235 private final byte[] sign; 1236 private final int[] index; 1237 private int next_buffer; 1238 private float[][] buffer; 1239 private final double[] QD; 1240 SVR_Q(svm_problem prob, svm_parameter param)1241 SVR_Q(svm_problem prob, svm_parameter param) 1242 { 1243 super(prob.l, prob.x, param); 1244 l = prob.l; 1245 cache = new Cache(l,(long)(param.cache_size*(1<<20))); 1246 QD = new double[2*l]; 1247 sign = new byte[2*l]; 1248 index = new int[2*l]; 1249 for(int k=0;k<l;k++) 1250 { 1251 sign[k] = 1; 1252 sign[k+l] = -1; 1253 index[k] = k; 1254 index[k+l] = k; 1255 QD[k] = kernel_function(k,k); 1256 QD[k+l] = QD[k]; 1257 } 1258 buffer = new float[2][2*l]; 1259 next_buffer = 0; 1260 } 1261 swap_index(int i, int j)1262 void swap_index(int i, int j) 1263 { 1264 do {byte tmp=sign[i]; sign[i]=sign[j]; sign[j]=tmp;} while(false); 1265 do {int tmp=index[i]; index[i]=index[j]; index[j]=tmp;} while(false); 1266 do {double tmp=QD[i]; QD[i]=QD[j]; QD[j]=tmp;} while(false); 1267 } 1268 get_Q(int i, int len)1269 float[] get_Q(int i, int len) 1270 { 1271 float[][] data = new float[1][]; 1272 int j, real_i = index[i]; 1273 if(cache.get_data(real_i,data,l) < l) 1274 { 1275 for(j=0;j<l;j++) 1276 data[0][j] = (float)kernel_function(real_i,j); 1277 } 1278 1279 // reorder and copy 1280 float buf[] = buffer[next_buffer]; 1281 next_buffer = 1 - next_buffer; 1282 byte si = sign[i]; 1283 for(j=0;j<len;j++) 1284 buf[j] = (float) si * sign[j] * data[0][index[j]]; 1285 return buf; 1286 } 1287 get_QD()1288 double[] get_QD() 1289 { 1290 return QD; 1291 } 1292 } 1293 1294 public class svm { 1295 // 1296 // construct and solve various formulations 1297 // 1298 public static final int LIBSVM_VERSION=324; 1299 public static final Random rand = new Random(); 1300 1301 private static svm_print_interface svm_print_stdout = new svm_print_interface() 1302 { 1303 public void print(String s) 1304 { 1305 System.out.print(s); 1306 System.out.flush(); 1307 } 1308 }; 1309 1310 private static svm_print_interface svm_print_string = svm_print_stdout; 1311 info(String s)1312 static void info(String s) 1313 { 1314 svm_print_string.print(s); 1315 } 1316 solve_c_svc(svm_problem prob, svm_parameter param, double[] alpha, Solver.SolutionInfo si, double Cp, double Cn)1317 private static void solve_c_svc(svm_problem prob, svm_parameter param, 1318 double[] alpha, Solver.SolutionInfo si, 1319 double Cp, double Cn) 1320 { 1321 int l = prob.l; 1322 double[] minus_ones = new double[l]; 1323 byte[] y = new byte[l]; 1324 1325 int i; 1326 1327 for(i=0;i<l;i++) 1328 { 1329 alpha[i] = 0; 1330 minus_ones[i] = -1; 1331 if(prob.y[i] > 0) y[i] = +1; else y[i] = -1; 1332 } 1333 1334 Solver s = new Solver(); 1335 s.Solve(l, new SVC_Q(prob,param,y), minus_ones, y, 1336 alpha, Cp, Cn, param.eps, si, param.shrinking); 1337 1338 double sum_alpha=0; 1339 for(i=0;i<l;i++) 1340 sum_alpha += alpha[i]; 1341 1342 if (Cp==Cn) 1343 svm.info("nu = "+sum_alpha/(Cp*prob.l)+"\n"); 1344 1345 for(i=0;i<l;i++) 1346 alpha[i] *= y[i]; 1347 } 1348 solve_nu_svc(svm_problem prob, svm_parameter param, double[] alpha, Solver.SolutionInfo si)1349 private static void solve_nu_svc(svm_problem prob, svm_parameter param, 1350 double[] alpha, Solver.SolutionInfo si) 1351 { 1352 int i; 1353 int l = prob.l; 1354 double nu = param.nu; 1355 1356 byte[] y = new byte[l]; 1357 1358 for(i=0;i<l;i++) 1359 if(prob.y[i]>0) 1360 y[i] = +1; 1361 else 1362 y[i] = -1; 1363 1364 double sum_pos = nu*l/2; 1365 double sum_neg = nu*l/2; 1366 1367 for(i=0;i<l;i++) 1368 if(y[i] == +1) 1369 { 1370 alpha[i] = Math.min(1.0,sum_pos); 1371 sum_pos -= alpha[i]; 1372 } 1373 else 1374 { 1375 alpha[i] = Math.min(1.0,sum_neg); 1376 sum_neg -= alpha[i]; 1377 } 1378 1379 double[] zeros = new double[l]; 1380 1381 for(i=0;i<l;i++) 1382 zeros[i] = 0; 1383 1384 Solver_NU s = new Solver_NU(); 1385 s.Solve(l, new SVC_Q(prob,param,y), zeros, y, 1386 alpha, 1.0, 1.0, param.eps, si, param.shrinking); 1387 double r = si.r; 1388 1389 svm.info("C = "+1/r+"\n"); 1390 1391 for(i=0;i<l;i++) 1392 alpha[i] *= y[i]/r; 1393 1394 si.rho /= r; 1395 si.obj /= (r*r); 1396 si.upper_bound_p = 1/r; 1397 si.upper_bound_n = 1/r; 1398 } 1399 solve_one_class(svm_problem prob, svm_parameter param, double[] alpha, Solver.SolutionInfo si)1400 private static void solve_one_class(svm_problem prob, svm_parameter param, 1401 double[] alpha, Solver.SolutionInfo si) 1402 { 1403 int l = prob.l; 1404 double[] zeros = new double[l]; 1405 byte[] ones = new byte[l]; 1406 int i; 1407 1408 int n = (int)(param.nu*prob.l); // # of alpha's at upper bound 1409 1410 for(i=0;i<n;i++) 1411 alpha[i] = 1; 1412 if(n<prob.l) 1413 alpha[n] = param.nu * prob.l - n; 1414 for(i=n+1;i<l;i++) 1415 alpha[i] = 0; 1416 1417 for(i=0;i<l;i++) 1418 { 1419 zeros[i] = 0; 1420 ones[i] = 1; 1421 } 1422 1423 Solver s = new Solver(); 1424 s.Solve(l, new ONE_CLASS_Q(prob,param), zeros, ones, 1425 alpha, 1.0, 1.0, param.eps, si, param.shrinking); 1426 } 1427 solve_epsilon_svr(svm_problem prob, svm_parameter param, double[] alpha, Solver.SolutionInfo si)1428 private static void solve_epsilon_svr(svm_problem prob, svm_parameter param, 1429 double[] alpha, Solver.SolutionInfo si) 1430 { 1431 int l = prob.l; 1432 double[] alpha2 = new double[2*l]; 1433 double[] linear_term = new double[2*l]; 1434 byte[] y = new byte[2*l]; 1435 int i; 1436 1437 for(i=0;i<l;i++) 1438 { 1439 alpha2[i] = 0; 1440 linear_term[i] = param.p - prob.y[i]; 1441 y[i] = 1; 1442 1443 alpha2[i+l] = 0; 1444 linear_term[i+l] = param.p + prob.y[i]; 1445 y[i+l] = -1; 1446 } 1447 1448 Solver s = new Solver(); 1449 s.Solve(2*l, new SVR_Q(prob,param), linear_term, y, 1450 alpha2, param.C, param.C, param.eps, si, param.shrinking); 1451 1452 double sum_alpha = 0; 1453 for(i=0;i<l;i++) 1454 { 1455 alpha[i] = alpha2[i] - alpha2[i+l]; 1456 sum_alpha += Math.abs(alpha[i]); 1457 } 1458 svm.info("nu = "+sum_alpha/(param.C*l)+"\n"); 1459 } 1460 solve_nu_svr(svm_problem prob, svm_parameter param, double[] alpha, Solver.SolutionInfo si)1461 private static void solve_nu_svr(svm_problem prob, svm_parameter param, 1462 double[] alpha, Solver.SolutionInfo si) 1463 { 1464 int l = prob.l; 1465 double C = param.C; 1466 double[] alpha2 = new double[2*l]; 1467 double[] linear_term = new double[2*l]; 1468 byte[] y = new byte[2*l]; 1469 int i; 1470 1471 double sum = C * param.nu * l / 2; 1472 for(i=0;i<l;i++) 1473 { 1474 alpha2[i] = alpha2[i+l] = Math.min(sum,C); 1475 sum -= alpha2[i]; 1476 1477 linear_term[i] = - prob.y[i]; 1478 y[i] = 1; 1479 1480 linear_term[i+l] = prob.y[i]; 1481 y[i+l] = -1; 1482 } 1483 1484 Solver_NU s = new Solver_NU(); 1485 s.Solve(2*l, new SVR_Q(prob,param), linear_term, y, 1486 alpha2, C, C, param.eps, si, param.shrinking); 1487 1488 svm.info("epsilon = "+(-si.r)+"\n"); 1489 1490 for(i=0;i<l;i++) 1491 alpha[i] = alpha2[i] - alpha2[i+l]; 1492 } 1493 1494 // 1495 // decision_function 1496 // 1497 static class decision_function 1498 { 1499 double[] alpha; 1500 double rho; 1501 }; 1502 svm_train_one( svm_problem prob, svm_parameter param, double Cp, double Cn)1503 static decision_function svm_train_one( 1504 svm_problem prob, svm_parameter param, 1505 double Cp, double Cn) 1506 { 1507 double[] alpha = new double[prob.l]; 1508 Solver.SolutionInfo si = new Solver.SolutionInfo(); 1509 switch(param.svm_type) 1510 { 1511 case svm_parameter.C_SVC: 1512 solve_c_svc(prob,param,alpha,si,Cp,Cn); 1513 break; 1514 case svm_parameter.NU_SVC: 1515 solve_nu_svc(prob,param,alpha,si); 1516 break; 1517 case svm_parameter.ONE_CLASS: 1518 solve_one_class(prob,param,alpha,si); 1519 break; 1520 case svm_parameter.EPSILON_SVR: 1521 solve_epsilon_svr(prob,param,alpha,si); 1522 break; 1523 case svm_parameter.NU_SVR: 1524 solve_nu_svr(prob,param,alpha,si); 1525 break; 1526 } 1527 1528 svm.info("obj = "+si.obj+", rho = "+si.rho+"\n"); 1529 1530 // output SVs 1531 1532 int nSV = 0; 1533 int nBSV = 0; 1534 for(int i=0;i<prob.l;i++) 1535 { 1536 if(Math.abs(alpha[i]) > 0) 1537 { 1538 ++nSV; 1539 if(prob.y[i] > 0) 1540 { 1541 if(Math.abs(alpha[i]) >= si.upper_bound_p) 1542 ++nBSV; 1543 } 1544 else 1545 { 1546 if(Math.abs(alpha[i]) >= si.upper_bound_n) 1547 ++nBSV; 1548 } 1549 } 1550 } 1551 1552 svm.info("nSV = "+nSV+", nBSV = "+nBSV+"\n"); 1553 1554 decision_function f = new decision_function(); 1555 f.alpha = alpha; 1556 f.rho = si.rho; 1557 return f; 1558 } 1559 1560 // Platt's binary SVM Probablistic Output: an improvement from Lin et al. sigmoid_train(int l, double[] dec_values, double[] labels, double[] probAB)1561 private static void sigmoid_train(int l, double[] dec_values, double[] labels, 1562 double[] probAB) 1563 { 1564 double A, B; 1565 double prior1=0, prior0 = 0; 1566 int i; 1567 1568 for (i=0;i<l;i++) 1569 if (labels[i] > 0) prior1+=1; 1570 else prior0+=1; 1571 1572 int max_iter=100; // Maximal number of iterations 1573 double min_step=1e-10; // Minimal step taken in line search 1574 double sigma=1e-12; // For numerically strict PD of Hessian 1575 double eps=1e-5; 1576 double hiTarget=(prior1+1.0)/(prior1+2.0); 1577 double loTarget=1/(prior0+2.0); 1578 double[] t= new double[l]; 1579 double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize; 1580 double newA,newB,newf,d1,d2; 1581 int iter; 1582 1583 // Initial Point and Initial Fun Value 1584 A=0.0; B=Math.log((prior0+1.0)/(prior1+1.0)); 1585 double fval = 0.0; 1586 1587 for (i=0;i<l;i++) 1588 { 1589 if (labels[i]>0) t[i]=hiTarget; 1590 else t[i]=loTarget; 1591 fApB = dec_values[i]*A+B; 1592 if (fApB>=0) 1593 fval += t[i]*fApB + Math.log(1+Math.exp(-fApB)); 1594 else 1595 fval += (t[i] - 1)*fApB +Math.log(1+Math.exp(fApB)); 1596 } 1597 for (iter=0;iter<max_iter;iter++) 1598 { 1599 // Update Gradient and Hessian (use H' = H + sigma I) 1600 h11=sigma; // numerically ensures strict PD 1601 h22=sigma; 1602 h21=0.0;g1=0.0;g2=0.0; 1603 for (i=0;i<l;i++) 1604 { 1605 fApB = dec_values[i]*A+B; 1606 if (fApB >= 0) 1607 { 1608 p=Math.exp(-fApB)/(1.0+Math.exp(-fApB)); 1609 q=1.0/(1.0+Math.exp(-fApB)); 1610 } 1611 else 1612 { 1613 p=1.0/(1.0+Math.exp(fApB)); 1614 q=Math.exp(fApB)/(1.0+Math.exp(fApB)); 1615 } 1616 d2=p*q; 1617 h11+=dec_values[i]*dec_values[i]*d2; 1618 h22+=d2; 1619 h21+=dec_values[i]*d2; 1620 d1=t[i]-p; 1621 g1+=dec_values[i]*d1; 1622 g2+=d1; 1623 } 1624 1625 // Stopping Criteria 1626 if (Math.abs(g1)<eps && Math.abs(g2)<eps) 1627 break; 1628 1629 // Finding Newton direction: -inv(H') * g 1630 det=h11*h22-h21*h21; 1631 dA=-(h22*g1 - h21 * g2) / det; 1632 dB=-(-h21*g1+ h11 * g2) / det; 1633 gd=g1*dA+g2*dB; 1634 1635 1636 stepsize = 1; // Line Search 1637 while (stepsize >= min_step) 1638 { 1639 newA = A + stepsize * dA; 1640 newB = B + stepsize * dB; 1641 1642 // New function value 1643 newf = 0.0; 1644 for (i=0;i<l;i++) 1645 { 1646 fApB = dec_values[i]*newA+newB; 1647 if (fApB >= 0) 1648 newf += t[i]*fApB + Math.log(1+Math.exp(-fApB)); 1649 else 1650 newf += (t[i] - 1)*fApB +Math.log(1+Math.exp(fApB)); 1651 } 1652 // Check sufficient decrease 1653 if (newf<fval+0.0001*stepsize*gd) 1654 { 1655 A=newA;B=newB;fval=newf; 1656 break; 1657 } 1658 else 1659 stepsize = stepsize / 2.0; 1660 } 1661 1662 if (stepsize < min_step) 1663 { 1664 svm.info("Line search fails in two-class probability estimates\n"); 1665 break; 1666 } 1667 } 1668 1669 if (iter>=max_iter) 1670 svm.info("Reaching maximal iterations in two-class probability estimates\n"); 1671 probAB[0]=A;probAB[1]=B; 1672 } 1673 sigmoid_predict(double decision_value, double A, double B)1674 private static double sigmoid_predict(double decision_value, double A, double B) 1675 { 1676 double fApB = decision_value*A+B; 1677 // 1-p used later; avoid catastrophic cancellation 1678 if (fApB >= 0) 1679 return Math.exp(-fApB)/(1.0+Math.exp(-fApB)); 1680 else 1681 return 1.0/(1+Math.exp(fApB)) ; 1682 } 1683 1684 // Method 2 from the multiclass_prob paper by Wu, Lin, and Weng multiclass_probability(int k, double[][] r, double[] p)1685 private static void multiclass_probability(int k, double[][] r, double[] p) 1686 { 1687 int t,j; 1688 int iter = 0, max_iter=Math.max(100,k); 1689 double[][] Q=new double[k][k]; 1690 double[] Qp=new double[k]; 1691 double pQp, eps=0.005/k; 1692 1693 for (t=0;t<k;t++) 1694 { 1695 p[t]=1.0/k; // Valid if k = 1 1696 Q[t][t]=0; 1697 for (j=0;j<t;j++) 1698 { 1699 Q[t][t]+=r[j][t]*r[j][t]; 1700 Q[t][j]=Q[j][t]; 1701 } 1702 for (j=t+1;j<k;j++) 1703 { 1704 Q[t][t]+=r[j][t]*r[j][t]; 1705 Q[t][j]=-r[j][t]*r[t][j]; 1706 } 1707 } 1708 for (iter=0;iter<max_iter;iter++) 1709 { 1710 // stopping condition, recalculate QP,pQP for numerical accuracy 1711 pQp=0; 1712 for (t=0;t<k;t++) 1713 { 1714 Qp[t]=0; 1715 for (j=0;j<k;j++) 1716 Qp[t]+=Q[t][j]*p[j]; 1717 pQp+=p[t]*Qp[t]; 1718 } 1719 double max_error=0; 1720 for (t=0;t<k;t++) 1721 { 1722 double error=Math.abs(Qp[t]-pQp); 1723 if (error>max_error) 1724 max_error=error; 1725 } 1726 if (max_error<eps) break; 1727 1728 for (t=0;t<k;t++) 1729 { 1730 double diff=(-Qp[t]+pQp)/Q[t][t]; 1731 p[t]+=diff; 1732 pQp=(pQp+diff*(diff*Q[t][t]+2*Qp[t]))/(1+diff)/(1+diff); 1733 for (j=0;j<k;j++) 1734 { 1735 Qp[j]=(Qp[j]+diff*Q[t][j])/(1+diff); 1736 p[j]/=(1+diff); 1737 } 1738 } 1739 } 1740 if (iter>=max_iter) 1741 svm.info("Exceeds max_iter in multiclass_prob\n"); 1742 } 1743 1744 // Cross-validation decision values for probability estimates svm_binary_svc_probability(svm_problem prob, svm_parameter param, double Cp, double Cn, double[] probAB)1745 private static void svm_binary_svc_probability(svm_problem prob, svm_parameter param, double Cp, double Cn, double[] probAB) 1746 { 1747 int i; 1748 int nr_fold = 5; 1749 int[] perm = new int[prob.l]; 1750 double[] dec_values = new double[prob.l]; 1751 1752 // random shuffle 1753 for(i=0;i<prob.l;i++) perm[i]=i; 1754 for(i=0;i<prob.l;i++) 1755 { 1756 int j = i+rand.nextInt(prob.l-i); 1757 do {int tmp=perm[i]; perm[i]=perm[j]; perm[j]=tmp;} while(false); 1758 } 1759 for(i=0;i<nr_fold;i++) 1760 { 1761 int begin = i*prob.l/nr_fold; 1762 int end = (i+1)*prob.l/nr_fold; 1763 int j,k; 1764 svm_problem subprob = new svm_problem(); 1765 1766 subprob.l = prob.l-(end-begin); 1767 subprob.x = new svm_node[subprob.l][]; 1768 subprob.y = new double[subprob.l]; 1769 1770 k=0; 1771 for(j=0;j<begin;j++) 1772 { 1773 subprob.x[k] = prob.x[perm[j]]; 1774 subprob.y[k] = prob.y[perm[j]]; 1775 ++k; 1776 } 1777 for(j=end;j<prob.l;j++) 1778 { 1779 subprob.x[k] = prob.x[perm[j]]; 1780 subprob.y[k] = prob.y[perm[j]]; 1781 ++k; 1782 } 1783 int p_count=0,n_count=0; 1784 for(j=0;j<k;j++) 1785 if(subprob.y[j]>0) 1786 p_count++; 1787 else 1788 n_count++; 1789 1790 if(p_count==0 && n_count==0) 1791 for(j=begin;j<end;j++) 1792 dec_values[perm[j]] = 0; 1793 else if(p_count > 0 && n_count == 0) 1794 for(j=begin;j<end;j++) 1795 dec_values[perm[j]] = 1; 1796 else if(p_count == 0 && n_count > 0) 1797 for(j=begin;j<end;j++) 1798 dec_values[perm[j]] = -1; 1799 else 1800 { 1801 svm_parameter subparam = (svm_parameter)param.clone(); 1802 subparam.probability=0; 1803 subparam.C=1.0; 1804 subparam.nr_weight=2; 1805 subparam.weight_label = new int[2]; 1806 subparam.weight = new double[2]; 1807 subparam.weight_label[0]=+1; 1808 subparam.weight_label[1]=-1; 1809 subparam.weight[0]=Cp; 1810 subparam.weight[1]=Cn; 1811 svm_model submodel = svm_train(subprob,subparam); 1812 for(j=begin;j<end;j++) 1813 { 1814 double[] dec_value=new double[1]; 1815 svm_predict_values(submodel,prob.x[perm[j]],dec_value); 1816 dec_values[perm[j]]=dec_value[0]; 1817 // ensure +1 -1 order; reason not using CV subroutine 1818 dec_values[perm[j]] *= submodel.label[0]; 1819 } 1820 } 1821 } 1822 sigmoid_train(prob.l,dec_values,prob.y,probAB); 1823 } 1824 1825 // Return parameter of a Laplace distribution svm_svr_probability(svm_problem prob, svm_parameter param)1826 private static double svm_svr_probability(svm_problem prob, svm_parameter param) 1827 { 1828 int i; 1829 int nr_fold = 5; 1830 double[] ymv = new double[prob.l]; 1831 double mae = 0; 1832 1833 svm_parameter newparam = (svm_parameter)param.clone(); 1834 newparam.probability = 0; 1835 svm_cross_validation(prob,newparam,nr_fold,ymv); 1836 for(i=0;i<prob.l;i++) 1837 { 1838 ymv[i]=prob.y[i]-ymv[i]; 1839 mae += Math.abs(ymv[i]); 1840 } 1841 mae /= prob.l; 1842 double std=Math.sqrt(2*mae*mae); 1843 int count=0; 1844 mae=0; 1845 for(i=0;i<prob.l;i++) 1846 if (Math.abs(ymv[i]) > 5*std) 1847 count=count+1; 1848 else 1849 mae+=Math.abs(ymv[i]); 1850 mae /= (prob.l-count); 1851 svm.info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma="+mae+"\n"); 1852 return mae; 1853 } 1854 1855 // label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data 1856 // perm, length l, must be allocated before calling this subroutine svm_group_classes(svm_problem prob, int[] nr_class_ret, int[][] label_ret, int[][] start_ret, int[][] count_ret, int[] perm)1857 private static void svm_group_classes(svm_problem prob, int[] nr_class_ret, int[][] label_ret, int[][] start_ret, int[][] count_ret, int[] perm) 1858 { 1859 int l = prob.l; 1860 int max_nr_class = 16; 1861 int nr_class = 0; 1862 int[] label = new int[max_nr_class]; 1863 int[] count = new int[max_nr_class]; 1864 int[] data_label = new int[l]; 1865 int i; 1866 1867 for(i=0;i<l;i++) 1868 { 1869 int this_label = (int)(prob.y[i]); 1870 int j; 1871 for(j=0;j<nr_class;j++) 1872 { 1873 if(this_label == label[j]) 1874 { 1875 ++count[j]; 1876 break; 1877 } 1878 } 1879 data_label[i] = j; 1880 if(j == nr_class) 1881 { 1882 if(nr_class == max_nr_class) 1883 { 1884 max_nr_class *= 2; 1885 int[] new_data = new int[max_nr_class]; 1886 System.arraycopy(label,0,new_data,0,label.length); 1887 label = new_data; 1888 new_data = new int[max_nr_class]; 1889 System.arraycopy(count,0,new_data,0,count.length); 1890 count = new_data; 1891 } 1892 label[nr_class] = this_label; 1893 count[nr_class] = 1; 1894 ++nr_class; 1895 } 1896 } 1897 1898 // 1899 // Labels are ordered by their first occurrence in the training set. 1900 // However, for two-class sets with -1/+1 labels and -1 appears first, 1901 // we swap labels to ensure that internally the binary SVM has positive data corresponding to the +1 instances. 1902 // 1903 if (nr_class == 2 && label[0] == -1 && label[1] == +1) 1904 { 1905 do {int tmp=label[0]; label[0]=label[1]; label[1]=tmp;} while(false); 1906 do {int tmp=count[0]; count[0]=count[1]; count[1]=tmp;} while(false); 1907 for(i=0;i<l;i++) 1908 { 1909 if(data_label[i] == 0) 1910 data_label[i] = 1; 1911 else 1912 data_label[i] = 0; 1913 } 1914 } 1915 1916 int[] start = new int[nr_class]; 1917 start[0] = 0; 1918 for(i=1;i<nr_class;i++) 1919 start[i] = start[i-1]+count[i-1]; 1920 for(i=0;i<l;i++) 1921 { 1922 perm[start[data_label[i]]] = i; 1923 ++start[data_label[i]]; 1924 } 1925 start[0] = 0; 1926 for(i=1;i<nr_class;i++) 1927 start[i] = start[i-1]+count[i-1]; 1928 1929 nr_class_ret[0] = nr_class; 1930 label_ret[0] = label; 1931 start_ret[0] = start; 1932 count_ret[0] = count; 1933 } 1934 1935 // 1936 // Interface functions 1937 // svm_train(svm_problem prob, svm_parameter param)1938 public static svm_model svm_train(svm_problem prob, svm_parameter param) 1939 { 1940 svm_model model = new svm_model(); 1941 model.param = param; 1942 1943 if(param.svm_type == svm_parameter.ONE_CLASS || 1944 param.svm_type == svm_parameter.EPSILON_SVR || 1945 param.svm_type == svm_parameter.NU_SVR) 1946 { 1947 // regression or one-class-svm 1948 model.nr_class = 2; 1949 model.label = null; 1950 model.nSV = null; 1951 model.probA = null; model.probB = null; 1952 model.sv_coef = new double[1][]; 1953 1954 if(param.probability == 1 && 1955 (param.svm_type == svm_parameter.EPSILON_SVR || 1956 param.svm_type == svm_parameter.NU_SVR)) 1957 { 1958 model.probA = new double[1]; 1959 model.probA[0] = svm_svr_probability(prob,param); 1960 } 1961 1962 decision_function f = svm_train_one(prob,param,0,0); 1963 model.rho = new double[1]; 1964 model.rho[0] = f.rho; 1965 1966 int nSV = 0; 1967 int i; 1968 for(i=0;i<prob.l;i++) 1969 if(Math.abs(f.alpha[i]) > 0) ++nSV; 1970 model.l = nSV; 1971 model.SV = new svm_node[nSV][]; 1972 model.sv_coef[0] = new double[nSV]; 1973 model.sv_indices = new int[nSV]; 1974 int j = 0; 1975 for(i=0;i<prob.l;i++) 1976 if(Math.abs(f.alpha[i]) > 0) 1977 { 1978 model.SV[j] = prob.x[i]; 1979 model.sv_coef[0][j] = f.alpha[i]; 1980 model.sv_indices[j] = i+1; 1981 ++j; 1982 } 1983 } 1984 else 1985 { 1986 // classification 1987 int l = prob.l; 1988 int[] tmp_nr_class = new int[1]; 1989 int[][] tmp_label = new int[1][]; 1990 int[][] tmp_start = new int[1][]; 1991 int[][] tmp_count = new int[1][]; 1992 int[] perm = new int[l]; 1993 1994 // group training data of the same class 1995 svm_group_classes(prob,tmp_nr_class,tmp_label,tmp_start,tmp_count,perm); 1996 int nr_class = tmp_nr_class[0]; 1997 int[] label = tmp_label[0]; 1998 int[] start = tmp_start[0]; 1999 int[] count = tmp_count[0]; 2000 2001 if(nr_class == 1) 2002 svm.info("WARNING: training data in only one class. See README for details.\n"); 2003 2004 svm_node[][] x = new svm_node[l][]; 2005 int i; 2006 for(i=0;i<l;i++) 2007 x[i] = prob.x[perm[i]]; 2008 2009 // calculate weighted C 2010 2011 double[] weighted_C = new double[nr_class]; 2012 for(i=0;i<nr_class;i++) 2013 weighted_C[i] = param.C; 2014 for(i=0;i<param.nr_weight;i++) 2015 { 2016 int j; 2017 for(j=0;j<nr_class;j++) 2018 if(param.weight_label[i] == label[j]) 2019 break; 2020 if(j == nr_class) 2021 System.err.print("WARNING: class label "+param.weight_label[i]+" specified in weight is not found\n"); 2022 else 2023 weighted_C[j] *= param.weight[i]; 2024 } 2025 2026 // train k*(k-1)/2 models 2027 2028 boolean[] nonzero = new boolean[l]; 2029 for(i=0;i<l;i++) 2030 nonzero[i] = false; 2031 decision_function[] f = new decision_function[nr_class*(nr_class-1)/2]; 2032 2033 double[] probA=null,probB=null; 2034 if (param.probability == 1) 2035 { 2036 probA=new double[nr_class*(nr_class-1)/2]; 2037 probB=new double[nr_class*(nr_class-1)/2]; 2038 } 2039 2040 int p = 0; 2041 for(i=0;i<nr_class;i++) 2042 for(int j=i+1;j<nr_class;j++) 2043 { 2044 svm_problem sub_prob = new svm_problem(); 2045 int si = start[i], sj = start[j]; 2046 int ci = count[i], cj = count[j]; 2047 sub_prob.l = ci+cj; 2048 sub_prob.x = new svm_node[sub_prob.l][]; 2049 sub_prob.y = new double[sub_prob.l]; 2050 int k; 2051 for(k=0;k<ci;k++) 2052 { 2053 sub_prob.x[k] = x[si+k]; 2054 sub_prob.y[k] = +1; 2055 } 2056 for(k=0;k<cj;k++) 2057 { 2058 sub_prob.x[ci+k] = x[sj+k]; 2059 sub_prob.y[ci+k] = -1; 2060 } 2061 2062 if(param.probability == 1) 2063 { 2064 double[] probAB=new double[2]; 2065 svm_binary_svc_probability(sub_prob,param,weighted_C[i],weighted_C[j],probAB); 2066 probA[p]=probAB[0]; 2067 probB[p]=probAB[1]; 2068 } 2069 2070 f[p] = svm_train_one(sub_prob,param,weighted_C[i],weighted_C[j]); 2071 for(k=0;k<ci;k++) 2072 if(!nonzero[si+k] && Math.abs(f[p].alpha[k]) > 0) 2073 nonzero[si+k] = true; 2074 for(k=0;k<cj;k++) 2075 if(!nonzero[sj+k] && Math.abs(f[p].alpha[ci+k]) > 0) 2076 nonzero[sj+k] = true; 2077 ++p; 2078 } 2079 2080 // build output 2081 2082 model.nr_class = nr_class; 2083 2084 model.label = new int[nr_class]; 2085 for(i=0;i<nr_class;i++) 2086 model.label[i] = label[i]; 2087 2088 model.rho = new double[nr_class*(nr_class-1)/2]; 2089 for(i=0;i<nr_class*(nr_class-1)/2;i++) 2090 model.rho[i] = f[i].rho; 2091 2092 if(param.probability == 1) 2093 { 2094 model.probA = new double[nr_class*(nr_class-1)/2]; 2095 model.probB = new double[nr_class*(nr_class-1)/2]; 2096 for(i=0;i<nr_class*(nr_class-1)/2;i++) 2097 { 2098 model.probA[i] = probA[i]; 2099 model.probB[i] = probB[i]; 2100 } 2101 } 2102 else 2103 { 2104 model.probA=null; 2105 model.probB=null; 2106 } 2107 2108 int total_sv = 0; 2109 int[] nz_count = new int[nr_class]; 2110 model.nSV = new int[nr_class]; 2111 for(i=0;i<nr_class;i++) 2112 { 2113 int nSV = 0; 2114 for(int j=0;j<count[i];j++) 2115 if(nonzero[start[i]+j]) 2116 { 2117 ++nSV; 2118 ++total_sv; 2119 } 2120 model.nSV[i] = nSV; 2121 nz_count[i] = nSV; 2122 } 2123 2124 svm.info("Total nSV = "+total_sv+"\n"); 2125 2126 model.l = total_sv; 2127 model.SV = new svm_node[total_sv][]; 2128 model.sv_indices = new int[total_sv]; 2129 p = 0; 2130 for(i=0;i<l;i++) 2131 if(nonzero[i]) 2132 { 2133 model.SV[p] = x[i]; 2134 model.sv_indices[p++] = perm[i] + 1; 2135 } 2136 2137 int[] nz_start = new int[nr_class]; 2138 nz_start[0] = 0; 2139 for(i=1;i<nr_class;i++) 2140 nz_start[i] = nz_start[i-1]+nz_count[i-1]; 2141 2142 model.sv_coef = new double[nr_class-1][]; 2143 for(i=0;i<nr_class-1;i++) 2144 model.sv_coef[i] = new double[total_sv]; 2145 2146 p = 0; 2147 for(i=0;i<nr_class;i++) 2148 for(int j=i+1;j<nr_class;j++) 2149 { 2150 // classifier (i,j): coefficients with 2151 // i are in sv_coef[j-1][nz_start[i]...], 2152 // j are in sv_coef[i][nz_start[j]...] 2153 2154 int si = start[i]; 2155 int sj = start[j]; 2156 int ci = count[i]; 2157 int cj = count[j]; 2158 2159 int q = nz_start[i]; 2160 int k; 2161 for(k=0;k<ci;k++) 2162 if(nonzero[si+k]) 2163 model.sv_coef[j-1][q++] = f[p].alpha[k]; 2164 q = nz_start[j]; 2165 for(k=0;k<cj;k++) 2166 if(nonzero[sj+k]) 2167 model.sv_coef[i][q++] = f[p].alpha[ci+k]; 2168 ++p; 2169 } 2170 } 2171 return model; 2172 } 2173 2174 // Stratified cross validation svm_cross_validation(svm_problem prob, svm_parameter param, int nr_fold, double[] target)2175 public static void svm_cross_validation(svm_problem prob, svm_parameter param, int nr_fold, double[] target) 2176 { 2177 int i; 2178 int[] fold_start = new int[nr_fold+1]; 2179 int l = prob.l; 2180 int[] perm = new int[l]; 2181 2182 // stratified cv may not give leave-one-out rate 2183 // Each class to l folds -> some folds may have zero elements 2184 if((param.svm_type == svm_parameter.C_SVC || 2185 param.svm_type == svm_parameter.NU_SVC) && nr_fold < l) 2186 { 2187 int[] tmp_nr_class = new int[1]; 2188 int[][] tmp_label = new int[1][]; 2189 int[][] tmp_start = new int[1][]; 2190 int[][] tmp_count = new int[1][]; 2191 2192 svm_group_classes(prob,tmp_nr_class,tmp_label,tmp_start,tmp_count,perm); 2193 2194 int nr_class = tmp_nr_class[0]; 2195 int[] start = tmp_start[0]; 2196 int[] count = tmp_count[0]; 2197 2198 // random shuffle and then data grouped by fold using the array perm 2199 int[] fold_count = new int[nr_fold]; 2200 int c; 2201 int[] index = new int[l]; 2202 for(i=0;i<l;i++) 2203 index[i]=perm[i]; 2204 for (c=0; c<nr_class; c++) 2205 for(i=0;i<count[c];i++) 2206 { 2207 int j = i+rand.nextInt(count[c]-i); 2208 do {int tmp=index[start[c]+j]; index[start[c]+j]=index[start[c]+i]; index[start[c]+i]=tmp;} while(false); 2209 } 2210 for(i=0;i<nr_fold;i++) 2211 { 2212 fold_count[i] = 0; 2213 for (c=0; c<nr_class;c++) 2214 fold_count[i]+=(i+1)*count[c]/nr_fold-i*count[c]/nr_fold; 2215 } 2216 fold_start[0]=0; 2217 for (i=1;i<=nr_fold;i++) 2218 fold_start[i] = fold_start[i-1]+fold_count[i-1]; 2219 for (c=0; c<nr_class;c++) 2220 for(i=0;i<nr_fold;i++) 2221 { 2222 int begin = start[c]+i*count[c]/nr_fold; 2223 int end = start[c]+(i+1)*count[c]/nr_fold; 2224 for(int j=begin;j<end;j++) 2225 { 2226 perm[fold_start[i]] = index[j]; 2227 fold_start[i]++; 2228 } 2229 } 2230 fold_start[0]=0; 2231 for (i=1;i<=nr_fold;i++) 2232 fold_start[i] = fold_start[i-1]+fold_count[i-1]; 2233 } 2234 else 2235 { 2236 for(i=0;i<l;i++) perm[i]=i; 2237 for(i=0;i<l;i++) 2238 { 2239 int j = i+rand.nextInt(l-i); 2240 do {int tmp=perm[i]; perm[i]=perm[j]; perm[j]=tmp;} while(false); 2241 } 2242 for(i=0;i<=nr_fold;i++) 2243 fold_start[i]=i*l/nr_fold; 2244 } 2245 2246 for(i=0;i<nr_fold;i++) 2247 { 2248 int begin = fold_start[i]; 2249 int end = fold_start[i+1]; 2250 int j,k; 2251 svm_problem subprob = new svm_problem(); 2252 2253 subprob.l = l-(end-begin); 2254 subprob.x = new svm_node[subprob.l][]; 2255 subprob.y = new double[subprob.l]; 2256 2257 k=0; 2258 for(j=0;j<begin;j++) 2259 { 2260 subprob.x[k] = prob.x[perm[j]]; 2261 subprob.y[k] = prob.y[perm[j]]; 2262 ++k; 2263 } 2264 for(j=end;j<l;j++) 2265 { 2266 subprob.x[k] = prob.x[perm[j]]; 2267 subprob.y[k] = prob.y[perm[j]]; 2268 ++k; 2269 } 2270 svm_model submodel = svm_train(subprob,param); 2271 if(param.probability==1 && 2272 (param.svm_type == svm_parameter.C_SVC || 2273 param.svm_type == svm_parameter.NU_SVC)) 2274 { 2275 double[] prob_estimates= new double[svm_get_nr_class(submodel)]; 2276 for(j=begin;j<end;j++) 2277 target[perm[j]] = svm_predict_probability(submodel,prob.x[perm[j]],prob_estimates); 2278 } 2279 else 2280 for(j=begin;j<end;j++) 2281 target[perm[j]] = svm_predict(submodel,prob.x[perm[j]]); 2282 } 2283 } 2284 svm_get_svm_type(svm_model model)2285 public static int svm_get_svm_type(svm_model model) 2286 { 2287 return model.param.svm_type; 2288 } 2289 svm_get_nr_class(svm_model model)2290 public static int svm_get_nr_class(svm_model model) 2291 { 2292 return model.nr_class; 2293 } 2294 svm_get_labels(svm_model model, int[] label)2295 public static void svm_get_labels(svm_model model, int[] label) 2296 { 2297 if (model.label != null) 2298 for(int i=0;i<model.nr_class;i++) 2299 label[i] = model.label[i]; 2300 } 2301 svm_get_sv_indices(svm_model model, int[] indices)2302 public static void svm_get_sv_indices(svm_model model, int[] indices) 2303 { 2304 if (model.sv_indices != null) 2305 for(int i=0;i<model.l;i++) 2306 indices[i] = model.sv_indices[i]; 2307 } 2308 svm_get_nr_sv(svm_model model)2309 public static int svm_get_nr_sv(svm_model model) 2310 { 2311 return model.l; 2312 } 2313 svm_get_svr_probability(svm_model model)2314 public static double svm_get_svr_probability(svm_model model) 2315 { 2316 if ((model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) && 2317 model.probA!=null) 2318 return model.probA[0]; 2319 else 2320 { 2321 System.err.print("Model doesn't contain information for SVR probability inference\n"); 2322 return 0; 2323 } 2324 } 2325 svm_predict_values(svm_model model, svm_node[] x, double[] dec_values)2326 public static double svm_predict_values(svm_model model, svm_node[] x, double[] dec_values) 2327 { 2328 int i; 2329 if(model.param.svm_type == svm_parameter.ONE_CLASS || 2330 model.param.svm_type == svm_parameter.EPSILON_SVR || 2331 model.param.svm_type == svm_parameter.NU_SVR) 2332 { 2333 double[] sv_coef = model.sv_coef[0]; 2334 double sum = 0; 2335 for(i=0;i<model.l;i++) 2336 sum += sv_coef[i] * Kernel.k_function(x,model.SV[i],model.param); 2337 sum -= model.rho[0]; 2338 dec_values[0] = sum; 2339 2340 if(model.param.svm_type == svm_parameter.ONE_CLASS) 2341 return (sum>0)?1:-1; 2342 else 2343 return sum; 2344 } 2345 else 2346 { 2347 int nr_class = model.nr_class; 2348 int l = model.l; 2349 2350 double[] kvalue = new double[l]; 2351 for(i=0;i<l;i++) 2352 kvalue[i] = Kernel.k_function(x,model.SV[i],model.param); 2353 2354 int[] start = new int[nr_class]; 2355 start[0] = 0; 2356 for(i=1;i<nr_class;i++) 2357 start[i] = start[i-1]+model.nSV[i-1]; 2358 2359 int[] vote = new int[nr_class]; 2360 for(i=0;i<nr_class;i++) 2361 vote[i] = 0; 2362 2363 int p=0; 2364 for(i=0;i<nr_class;i++) 2365 for(int j=i+1;j<nr_class;j++) 2366 { 2367 double sum = 0; 2368 int si = start[i]; 2369 int sj = start[j]; 2370 int ci = model.nSV[i]; 2371 int cj = model.nSV[j]; 2372 2373 int k; 2374 double[] coef1 = model.sv_coef[j-1]; 2375 double[] coef2 = model.sv_coef[i]; 2376 for(k=0;k<ci;k++) 2377 sum += coef1[si+k] * kvalue[si+k]; 2378 for(k=0;k<cj;k++) 2379 sum += coef2[sj+k] * kvalue[sj+k]; 2380 sum -= model.rho[p]; 2381 dec_values[p] = sum; 2382 2383 if(dec_values[p] > 0) 2384 ++vote[i]; 2385 else 2386 ++vote[j]; 2387 p++; 2388 } 2389 2390 int vote_max_idx = 0; 2391 for(i=1;i<nr_class;i++) 2392 if(vote[i] > vote[vote_max_idx]) 2393 vote_max_idx = i; 2394 2395 return model.label[vote_max_idx]; 2396 } 2397 } 2398 svm_predict(svm_model model, svm_node[] x)2399 public static double svm_predict(svm_model model, svm_node[] x) 2400 { 2401 int nr_class = model.nr_class; 2402 double[] dec_values; 2403 if(model.param.svm_type == svm_parameter.ONE_CLASS || 2404 model.param.svm_type == svm_parameter.EPSILON_SVR || 2405 model.param.svm_type == svm_parameter.NU_SVR) 2406 dec_values = new double[1]; 2407 else 2408 dec_values = new double[nr_class*(nr_class-1)/2]; 2409 double pred_result = svm_predict_values(model, x, dec_values); 2410 return pred_result; 2411 } 2412 svm_predict_probability(svm_model model, svm_node[] x, double[] prob_estimates)2413 public static double svm_predict_probability(svm_model model, svm_node[] x, double[] prob_estimates) 2414 { 2415 if ((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) && 2416 model.probA!=null && model.probB!=null) 2417 { 2418 int i; 2419 int nr_class = model.nr_class; 2420 double[] dec_values = new double[nr_class*(nr_class-1)/2]; 2421 svm_predict_values(model, x, dec_values); 2422 2423 double min_prob=1e-7; 2424 double[][] pairwise_prob=new double[nr_class][nr_class]; 2425 2426 int k=0; 2427 for(i=0;i<nr_class;i++) 2428 for(int j=i+1;j<nr_class;j++) 2429 { 2430 pairwise_prob[i][j]=Math.min(Math.max(sigmoid_predict(dec_values[k],model.probA[k],model.probB[k]),min_prob),1-min_prob); 2431 pairwise_prob[j][i]=1-pairwise_prob[i][j]; 2432 k++; 2433 } 2434 if (nr_class == 2) 2435 { 2436 prob_estimates[0] = pairwise_prob[0][1]; 2437 prob_estimates[1] = pairwise_prob[1][0]; 2438 } 2439 else 2440 multiclass_probability(nr_class,pairwise_prob,prob_estimates); 2441 2442 int prob_max_idx = 0; 2443 for(i=1;i<nr_class;i++) 2444 if(prob_estimates[i] > prob_estimates[prob_max_idx]) 2445 prob_max_idx = i; 2446 return model.label[prob_max_idx]; 2447 } 2448 else 2449 return svm_predict(model, x); 2450 } 2451 2452 static final String svm_type_table[] = 2453 { 2454 "c_svc","nu_svc","one_class","epsilon_svr","nu_svr", 2455 }; 2456 2457 static final String kernel_type_table[]= 2458 { 2459 "linear","polynomial","rbf","sigmoid","precomputed" 2460 }; 2461 svm_save_model(String model_file_name, svm_model model)2462 public static void svm_save_model(String model_file_name, svm_model model) throws IOException 2463 { 2464 DataOutputStream fp = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(model_file_name))); 2465 2466 svm_parameter param = model.param; 2467 2468 fp.writeBytes("svm_type "+svm_type_table[param.svm_type]+"\n"); 2469 fp.writeBytes("kernel_type "+kernel_type_table[param.kernel_type]+"\n"); 2470 2471 if(param.kernel_type == svm_parameter.POLY) 2472 fp.writeBytes("degree "+param.degree+"\n"); 2473 2474 if(param.kernel_type == svm_parameter.POLY || 2475 param.kernel_type == svm_parameter.RBF || 2476 param.kernel_type == svm_parameter.SIGMOID) 2477 fp.writeBytes("gamma "+param.gamma+"\n"); 2478 2479 if(param.kernel_type == svm_parameter.POLY || 2480 param.kernel_type == svm_parameter.SIGMOID) 2481 fp.writeBytes("coef0 "+param.coef0+"\n"); 2482 2483 int nr_class = model.nr_class; 2484 int l = model.l; 2485 fp.writeBytes("nr_class "+nr_class+"\n"); 2486 fp.writeBytes("total_sv "+l+"\n"); 2487 2488 { 2489 fp.writeBytes("rho"); 2490 for(int i=0;i<nr_class*(nr_class-1)/2;i++) 2491 fp.writeBytes(" "+model.rho[i]); 2492 fp.writeBytes("\n"); 2493 } 2494 2495 if(model.label != null) 2496 { 2497 fp.writeBytes("label"); 2498 for(int i=0;i<nr_class;i++) 2499 fp.writeBytes(" "+model.label[i]); 2500 fp.writeBytes("\n"); 2501 } 2502 2503 if(model.probA != null) // regression has probA only 2504 { 2505 fp.writeBytes("probA"); 2506 for(int i=0;i<nr_class*(nr_class-1)/2;i++) 2507 fp.writeBytes(" "+model.probA[i]); 2508 fp.writeBytes("\n"); 2509 } 2510 if(model.probB != null) 2511 { 2512 fp.writeBytes("probB"); 2513 for(int i=0;i<nr_class*(nr_class-1)/2;i++) 2514 fp.writeBytes(" "+model.probB[i]); 2515 fp.writeBytes("\n"); 2516 } 2517 2518 if(model.nSV != null) 2519 { 2520 fp.writeBytes("nr_sv"); 2521 for(int i=0;i<nr_class;i++) 2522 fp.writeBytes(" "+model.nSV[i]); 2523 fp.writeBytes("\n"); 2524 } 2525 2526 fp.writeBytes("SV\n"); 2527 double[][] sv_coef = model.sv_coef; 2528 svm_node[][] SV = model.SV; 2529 2530 for(int i=0;i<l;i++) 2531 { 2532 for(int j=0;j<nr_class-1;j++) 2533 fp.writeBytes(sv_coef[j][i]+" "); 2534 2535 svm_node[] p = SV[i]; 2536 if(param.kernel_type == svm_parameter.PRECOMPUTED) 2537 fp.writeBytes("0:"+(int)(p[0].value)); 2538 else 2539 for(int j=0;j<p.length;j++) 2540 fp.writeBytes(p[j].index+":"+p[j].value+" "); 2541 fp.writeBytes("\n"); 2542 } 2543 2544 fp.close(); 2545 } 2546 atof(String s)2547 private static double atof(String s) 2548 { 2549 return Double.valueOf(s).doubleValue(); 2550 } 2551 atoi(String s)2552 private static int atoi(String s) 2553 { 2554 return Integer.parseInt(s); 2555 } 2556 read_model_header(BufferedReader fp, svm_model model)2557 private static boolean read_model_header(BufferedReader fp, svm_model model) 2558 { 2559 svm_parameter param = new svm_parameter(); 2560 model.param = param; 2561 // parameters for training only won't be assigned, but arrays are assigned as null for safety 2562 param.nr_weight = 0; 2563 param.weight_label = null; 2564 param.weight = null; 2565 2566 try 2567 { 2568 while(true) 2569 { 2570 String cmd = fp.readLine(); 2571 String arg = cmd.substring(cmd.indexOf(' ')+1); 2572 2573 if(cmd.startsWith("svm_type")) 2574 { 2575 int i; 2576 for(i=0;i<svm_type_table.length;i++) 2577 { 2578 if(arg.indexOf(svm_type_table[i])!=-1) 2579 { 2580 param.svm_type=i; 2581 break; 2582 } 2583 } 2584 if(i == svm_type_table.length) 2585 { 2586 System.err.print("unknown svm type.\n"); 2587 return false; 2588 } 2589 } 2590 else if(cmd.startsWith("kernel_type")) 2591 { 2592 int i; 2593 for(i=0;i<kernel_type_table.length;i++) 2594 { 2595 if(arg.indexOf(kernel_type_table[i])!=-1) 2596 { 2597 param.kernel_type=i; 2598 break; 2599 } 2600 } 2601 if(i == kernel_type_table.length) 2602 { 2603 System.err.print("unknown kernel function.\n"); 2604 return false; 2605 } 2606 } 2607 else if(cmd.startsWith("degree")) 2608 param.degree = atoi(arg); 2609 else if(cmd.startsWith("gamma")) 2610 param.gamma = atof(arg); 2611 else if(cmd.startsWith("coef0")) 2612 param.coef0 = atof(arg); 2613 else if(cmd.startsWith("nr_class")) 2614 model.nr_class = atoi(arg); 2615 else if(cmd.startsWith("total_sv")) 2616 model.l = atoi(arg); 2617 else if(cmd.startsWith("rho")) 2618 { 2619 int n = model.nr_class * (model.nr_class-1)/2; 2620 model.rho = new double[n]; 2621 StringTokenizer st = new StringTokenizer(arg); 2622 for(int i=0;i<n;i++) 2623 model.rho[i] = atof(st.nextToken()); 2624 } 2625 else if(cmd.startsWith("label")) 2626 { 2627 int n = model.nr_class; 2628 model.label = new int[n]; 2629 StringTokenizer st = new StringTokenizer(arg); 2630 for(int i=0;i<n;i++) 2631 model.label[i] = atoi(st.nextToken()); 2632 } 2633 else if(cmd.startsWith("probA")) 2634 { 2635 int n = model.nr_class*(model.nr_class-1)/2; 2636 model.probA = new double[n]; 2637 StringTokenizer st = new StringTokenizer(arg); 2638 for(int i=0;i<n;i++) 2639 model.probA[i] = atof(st.nextToken()); 2640 } 2641 else if(cmd.startsWith("probB")) 2642 { 2643 int n = model.nr_class*(model.nr_class-1)/2; 2644 model.probB = new double[n]; 2645 StringTokenizer st = new StringTokenizer(arg); 2646 for(int i=0;i<n;i++) 2647 model.probB[i] = atof(st.nextToken()); 2648 } 2649 else if(cmd.startsWith("nr_sv")) 2650 { 2651 int n = model.nr_class; 2652 model.nSV = new int[n]; 2653 StringTokenizer st = new StringTokenizer(arg); 2654 for(int i=0;i<n;i++) 2655 model.nSV[i] = atoi(st.nextToken()); 2656 } 2657 else if(cmd.startsWith("SV")) 2658 { 2659 break; 2660 } 2661 else 2662 { 2663 System.err.print("unknown text in model file: ["+cmd+"]\n"); 2664 return false; 2665 } 2666 } 2667 } 2668 catch(Exception e) 2669 { 2670 return false; 2671 } 2672 return true; 2673 } 2674 svm_load_model(String model_file_name)2675 public static svm_model svm_load_model(String model_file_name) throws IOException 2676 { 2677 return svm_load_model(new BufferedReader(new FileReader(model_file_name))); 2678 } 2679 svm_load_model(BufferedReader fp)2680 public static svm_model svm_load_model(BufferedReader fp) throws IOException 2681 { 2682 // read parameters 2683 2684 svm_model model = new svm_model(); 2685 model.rho = null; 2686 model.probA = null; 2687 model.probB = null; 2688 model.label = null; 2689 model.nSV = null; 2690 2691 // read header 2692 if (!read_model_header(fp, model)) 2693 { 2694 System.err.print("ERROR: failed to read model\n"); 2695 return null; 2696 } 2697 2698 // read sv_coef and SV 2699 2700 int m = model.nr_class - 1; 2701 int l = model.l; 2702 model.sv_coef = new double[m][l]; 2703 model.SV = new svm_node[l][]; 2704 2705 for(int i=0;i<l;i++) 2706 { 2707 String line = fp.readLine(); 2708 StringTokenizer st = new StringTokenizer(line," \t\n\r\f:"); 2709 2710 for(int k=0;k<m;k++) 2711 model.sv_coef[k][i] = atof(st.nextToken()); 2712 int n = st.countTokens()/2; 2713 model.SV[i] = new svm_node[n]; 2714 for(int j=0;j<n;j++) 2715 { 2716 model.SV[i][j] = new svm_node(); 2717 model.SV[i][j].index = atoi(st.nextToken()); 2718 model.SV[i][j].value = atof(st.nextToken()); 2719 } 2720 } 2721 2722 fp.close(); 2723 return model; 2724 } 2725 svm_check_parameter(svm_problem prob, svm_parameter param)2726 public static String svm_check_parameter(svm_problem prob, svm_parameter param) 2727 { 2728 // svm_type 2729 2730 int svm_type = param.svm_type; 2731 if(svm_type != svm_parameter.C_SVC && 2732 svm_type != svm_parameter.NU_SVC && 2733 svm_type != svm_parameter.ONE_CLASS && 2734 svm_type != svm_parameter.EPSILON_SVR && 2735 svm_type != svm_parameter.NU_SVR) 2736 return "unknown svm type"; 2737 2738 // kernel_type, degree 2739 2740 int kernel_type = param.kernel_type; 2741 if(kernel_type != svm_parameter.LINEAR && 2742 kernel_type != svm_parameter.POLY && 2743 kernel_type != svm_parameter.RBF && 2744 kernel_type != svm_parameter.SIGMOID && 2745 kernel_type != svm_parameter.PRECOMPUTED) 2746 return "unknown kernel type"; 2747 2748 if((kernel_type == svm_parameter.POLY || 2749 kernel_type == svm_parameter.RBF || 2750 kernel_type == svm_parameter.SIGMOID) && 2751 param.gamma < 0) 2752 return "gamma < 0"; 2753 2754 if(kernel_type == svm_parameter.POLY && param.degree < 0) 2755 return "degree of polynomial kernel < 0"; 2756 2757 // cache_size,eps,C,nu,p,shrinking 2758 2759 if(param.cache_size <= 0) 2760 return "cache_size <= 0"; 2761 2762 if(param.eps <= 0) 2763 return "eps <= 0"; 2764 2765 if(svm_type == svm_parameter.C_SVC || 2766 svm_type == svm_parameter.EPSILON_SVR || 2767 svm_type == svm_parameter.NU_SVR) 2768 if(param.C <= 0) 2769 return "C <= 0"; 2770 2771 if(svm_type == svm_parameter.NU_SVC || 2772 svm_type == svm_parameter.ONE_CLASS || 2773 svm_type == svm_parameter.NU_SVR) 2774 if(param.nu <= 0 || param.nu > 1) 2775 return "nu <= 0 or nu > 1"; 2776 2777 if(svm_type == svm_parameter.EPSILON_SVR) 2778 if(param.p < 0) 2779 return "p < 0"; 2780 2781 if(param.shrinking != 0 && 2782 param.shrinking != 1) 2783 return "shrinking != 0 and shrinking != 1"; 2784 2785 if(param.probability != 0 && 2786 param.probability != 1) 2787 return "probability != 0 and probability != 1"; 2788 2789 if(param.probability == 1 && 2790 svm_type == svm_parameter.ONE_CLASS) 2791 return "one-class SVM probability output not supported yet"; 2792 2793 // check whether nu-svc is feasible 2794 2795 if(svm_type == svm_parameter.NU_SVC) 2796 { 2797 int l = prob.l; 2798 int max_nr_class = 16; 2799 int nr_class = 0; 2800 int[] label = new int[max_nr_class]; 2801 int[] count = new int[max_nr_class]; 2802 2803 int i; 2804 for(i=0;i<l;i++) 2805 { 2806 int this_label = (int)prob.y[i]; 2807 int j; 2808 for(j=0;j<nr_class;j++) 2809 if(this_label == label[j]) 2810 { 2811 ++count[j]; 2812 break; 2813 } 2814 2815 if(j == nr_class) 2816 { 2817 if(nr_class == max_nr_class) 2818 { 2819 max_nr_class *= 2; 2820 int[] new_data = new int[max_nr_class]; 2821 System.arraycopy(label,0,new_data,0,label.length); 2822 label = new_data; 2823 2824 new_data = new int[max_nr_class]; 2825 System.arraycopy(count,0,new_data,0,count.length); 2826 count = new_data; 2827 } 2828 label[nr_class] = this_label; 2829 count[nr_class] = 1; 2830 ++nr_class; 2831 } 2832 } 2833 2834 for(i=0;i<nr_class;i++) 2835 { 2836 int n1 = count[i]; 2837 for(int j=i+1;j<nr_class;j++) 2838 { 2839 int n2 = count[j]; 2840 if(param.nu*(n1+n2)/2 > Math.min(n1,n2)) 2841 return "specified nu is infeasible"; 2842 } 2843 } 2844 } 2845 2846 return null; 2847 } 2848 svm_check_probability_model(svm_model model)2849 public static int svm_check_probability_model(svm_model model) 2850 { 2851 if (((model.param.svm_type == svm_parameter.C_SVC || model.param.svm_type == svm_parameter.NU_SVC) && 2852 model.probA!=null && model.probB!=null) || 2853 ((model.param.svm_type == svm_parameter.EPSILON_SVR || model.param.svm_type == svm_parameter.NU_SVR) && 2854 model.probA!=null)) 2855 return 1; 2856 else 2857 return 0; 2858 } 2859 svm_set_print_string_function(svm_print_interface print_func)2860 public static void svm_set_print_string_function(svm_print_interface print_func) 2861 { 2862 if (print_func == null) 2863 svm_print_string = svm_print_stdout; 2864 else 2865 svm_print_string = print_func; 2866 } 2867 } 2868