function [CC]=train_sc(D,classlabel,MODE,W)
% Train a (statistical) classifier
%
%  CC = train_sc(D,classlabel)
%  CC = train_sc(D,classlabel,MODE)
%  CC = train_sc(D,classlabel,MODE, W)
%       weighting D(k,:) with weight W(k) (not all classifiers support weighting)
%
% CC contains the model parameters of a classifier which can be applied
%   to test data using test_sc.
%   R = test_sc(CC,D,...)
%
%   D           training samples (each row is a sample, each column is a feature)
%   classlabel  labels of each sample, must have the same number of rows as D.
%               Two different encodings are supported:
%               {-1,1}-encoding (multiple classes with separate columns for each class) or
%               1..M encoding.
%               So [1;2;3;1;4] is equivalent to
%                       [+1,-1,-1,-1;
%                        -1,+1,-1,-1;
%                        -1,-1,+1,-1;
%                        +1,-1,-1,-1;
%                        -1,-1,-1,+1]
%               Note, samples with classlabel=0 are ignored.
%
%  The following classifier types are supported MODE.TYPE
%    'MDA'      mahalanobis distance based classifier [1]
%    'MD2'      mahalanobis distance based classifier [1]
%    'MD3'      mahalanobis distance based classifier [1]
%    'GRB'      Gaussian radial basis function     [1]
%    'QDA'      quadratic discriminant analysis    [1]
%    'LD2'      linear discriminant analysis (see LDBC2) [1]
%               MODE.hyperparameter.gamma: regularization parameter [default 0]
%    'LD3', 'FDA', 'LDA', 'FLDA'
%               linear discriminant analysis (see LDBC3) [1]
%               MODE.hyperparameter.gamma: regularization parameter [default 0]
%    'LD4'      linear discriminant analysis (see LDBC4) [1]
%               MODE.hyperparameter.gamma: regularization parameter [default 0]
%    'LD5'      another LDA (motivated by CSP)
%               MODE.hyperparameter.gamma: regularization parameter [default 0]
%    'RDA'      regularized discriminant analysis [7]
%               MODE.hyperparameter.gamma: regularization parameter
%               MODE.hyperparameter.lambda =
%               gamma = 0, lambda = 0 : MDA
%               gamma = 0, lambda = 1 : LDA [default]
%               Hint: the hyperparameters are used only in test_sc.m; testing
%               different hyperparameters does not require repeated calls to train_sc,
%               it is sufficient to modify CC.hyperparameter before calling test_sc.
%    'GDBC'     general distance based classifier  [1]
%    ''         statistical classifier, requires Mode argument in TEST_SC
%    '###/DELETION'  if the data contains missing values (encoded as NaNs),
%               a row-wise or column-wise deletion (depending on which method
%               removes fewer data values) is applied;
%    '###/GSVD' GSVD and statistical classifier [2,3],
%    '###/sparse'  sparse [5]
%               '###' must be 'LDA' or any other classifier
%    'PLS'      (linear) partial least squares regression
%    'REG'      regression analysis;
%    'WienerHopf'  Wiener-Hopf equation
%    'NBC'      Naive Bayesian Classifier [6]
%    'aNBC'     Augmented Naive Bayesian Classifier [6]
%    'NBPW'     Naive Bayesian Parzen Window [9]
%
%    'PLA'      Perceptron Learning Algorithm [11]
%               MODE.hyperparameter.alpha = alpha [default: 1]
%                w = w + alpha * e'*x
%    'LMS', 'AdaLine'  Least mean squares, adaptive line element, Widrow-Hoff, delta rule
%               MODE.hyperparameter.alpha = alpha [default: 1]
%    'Winnow2'  Winnow2 algorithm [12]
%
%    'PSVM'     Proximal SVM [8]
%               MODE.hyperparameter.nu  (default: 1.0)
%    'LPM'      Linear Programming Machine
%               uses and requires train_LPM of the iLog CPLEX optimizer
%               MODE.hyperparameter.c_value =
%    'CSP'      CommonSpatialPattern is very experimental and just a hack
%               uses a smoothing window of 50 samples.
%    'SVM','SVM1r'  support vector machines, one-vs-rest
%               MODE.hyperparameter.c_value =
%    'SVM11'    support vector machines, one-vs-one + voting
%               MODE.hyperparameter.c_value =
%    'RBF'      Support Vector Machines with RBF Kernel
%               MODE.hyperparameter.c_value =
%               MODE.hyperparameter.gamma =
%    'SVM:LIB'    libSVM (default SVM algorithm)
%    'SVM:bioinfo' uses and requires svmtrain from the bioinfo toolbox
%    'SVM:OSU'   uses and requires mexSVMTrain from the OSU-SVM toolbox
%    'SVM:LOO'   uses and requires svcm_train from the LOO-SVM toolbox
%    'SVM:Gunn'  uses and requires the svc functions from the Gunn-SVM toolbox
%    'SVM:KM'    uses and requires the svmclass function from the KM-SVM toolbox
%    'SVM:LINz'  LibLinear [10] (requires train.mex from LibLinear somewhere in the path)
%            z=0 (default) LibLinear with -- L2-regularized logistic regression
%            z=1 LibLinear with -- L2-loss support vector machines (dual)
%            z=2 LibLinear with -- L2-loss support vector machines (primal)
%            z=3 LibLinear with -- L1-loss support vector machines (dual)
%    'SVM:LIN4'  LibLinear with -- multi-class support vector machines by Crammer and Singer
%    'DT'       decision tree - not implemented yet.
%
% {'REG','MDA','MD2','QDA','QDA2','LD2','LD3','LD4','LD5','LD6','NBC','aNBC','WienerHopf','LDA/GSVD','MDA/GSVD', 'LDA/sparse','MDA/sparse', 'PLA', 'LMS','LDA/DELETION','MDA/DELETION','NBC/DELETION','RDA/DELETION','REG/DELETION','RDA','GDBC','SVM','RBF','PSVM','SVM11','SVM:LIN4','SVM:LIN0','SVM:LIN1','SVM:LIN2','SVM:LIN3','WINNOW', 'DT'};
%
% CC contains the model parameters of a classifier. Some time ago,
%     CC was a statistical classifier containing the mean
%     and the covariance of the data of each class (encoded in the
%     so-called "extended covariance matrices"). Nowadays, other
%     classifiers are supported as well.
%
% see also: TEST_SC, COVM, ROW_COL_DELETION
%
% References:
% [1] R. Duda, P. Hart, and D. Stork, Pattern Classification, second ed.
%       John Wiley & Sons, 2001.
% [2] Peg Howland and Haesun Park,
%       Generalizing Discriminant Analysis Using the Generalized Singular Value Decomposition
%       IEEE Transactions on Pattern Analysis and Machine Intelligence, 26(8), 2004.
%       dx.doi.org/10.1109/TPAMI.2004.46
% [3] http://www-static.cc.gatech.edu/~kihwan23/face_recog_gsvd.htm
% [4] Jieping Ye, Ravi Janardan, Cheong Hee Park, Haesun Park
%       A new optimization criterion for generalized discriminant analysis on undersampled problems.
%       The Third IEEE International Conference on Data Mining, Melbourne, Florida, USA
%       November 19 - 22, 2003
% [5] J.D. Tebbens and P. Schlesinger (2006),
%       Improving Implementation of Linear Discriminant Analysis for the Small Sample Size Problem
%       Computational Statistics & Data Analysis, vol 52(1): 423-437, 2007
%       http://www.cs.cas.cz/mweb/download/publi/JdtSchl2006.pdf
% [6] H. Zhang, The optimality of Naive Bayes,
%       http://www.cs.unb.ca/profs/hzhang/publications/FLAIRS04ZhangH.pdf
% [7] J.H. Friedman. Regularized discriminant analysis.
%       Journal of the American Statistical Association, 84:165–175, 1989.
% [8] G. Fung and O.L. Mangasarian, Proximal Support Vector Machine Classifiers, KDD 2001.
%       Eds. F. Provost and R. Srikant, Proc. KDD-2001: Knowledge Discovery and Data Mining, August 26-29, 2001, San Francisco, CA.
%       p. 77-86.
% [9] Kai Keng Ang, Zhang Yang Chin, Haihong Zhang, Cuntai Guan.
%       Filter Bank Common Spatial Pattern (FBCSP) in Brain-Computer Interface.
%       IEEE International Joint Conference on Neural Networks, 2008. IJCNN 2008. (IEEE World Congress on Computational Intelligence).
%       1-8 June 2008 Page(s):2390 - 2397
% [10] R.-E. Fan, K.-W. Chang, C.-J. Hsieh, X.-R. Wang, and C.-J. Lin.
%       LIBLINEAR: A Library for Large Linear Classification, Journal of Machine Learning Research 9(2008), 1871-1874.
%       Software available at http://www.csie.ntu.edu.tw/~cjlin/liblinear
% [11] http://en.wikipedia.org/wiki/Perceptron#Learning_algorithm
% [12] Littlestone, N. (1988)
%       "Learning Quickly When Irrelevant Attributes Abound: A New Linear-threshold Algorithm"
%       Machine Learning 2: 285-318
%       http://en.wikipedia.org/wiki/Winnow_(algorithm)

%       $Id$
%       Copyright (C) 2005,2006,2007,2008,2009,2010 by Alois Schloegl <alois.schloegl@gmail.com>
%       This function is part of the NaN-toolbox
%       http://pub.ist.ac.at/~schloegl/matlab/NaN/

% This program is free software; you can redistribute it and/or
% modify it under the terms of the GNU General Public License
% as published by the Free Software Foundation; either version 3
% of the  License, or (at your option) any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc., 51 Franklin Street - Fifth Floor, Boston, MA 02110-1301, USA.
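
% Usage example (a minimal sketch, not part of the original documentation).
% The data are made up for illustration; only the calling conventions
% described in the help text above are used, and the NaN-toolbox is
% assumed to be on the path.
%
%   D  = [randn(50,2); randn(50,2)+2];    % 100 samples (rows), 2 features (columns)
%   cl = [ones(50,1); 2*ones(50,1)];      % class labels in 1..M encoding
%   CC = train_sc(D, cl, 'LDA');          % MODE given as a string
%
%   MODE.TYPE = 'LD3';                    % MODE given as a struct
%   MODE.hyperparameter.gamma = 0.1;      % regularization parameter
%   CC2 = train_sc(D, cl, MODE);
%
%   R  = test_sc(CC, D);                  % apply the trained classifier to data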

    if nargin<2,
        error('insufficient input arguments\n\tusage: train_sc(D,C,...)\n');
    end
    if nargin<3, MODE = 'LDA'; end
    if nargin<4, W = []; end
    if ischar(MODE)
        tmp = MODE;
        clear MODE;
        MODE.TYPE = tmp;
    elseif ~isfield(MODE,'TYPE')
        MODE.TYPE='';
    end

    if isfield(MODE,'hyperparameters') && ~isfield(MODE,'hyperparameter'),
        %% for backwards compatibility, this might become obsolete
        warning('MODE.hyperparameters are used, You should use MODE.hyperparameter instead!!!');
        MODE.hyperparameter = MODE.hyperparameters;
    end

    sz = size(D);
    if sz(1)~=size(classlabel,1),
        error('length of data and classlabel does not fit');
    end

    % remove all NaN's
    if 1,
        % several classifier can deal with NaN's, there is no need to remove them.
    elseif isempty(W)
        %% TODO: some classifiers can deal with NaN's in D. Test whether this can be relaxed.
        %ix = any(isnan([classlabel]),2);
        ix = any(isnan([D,classlabel]),2);
        D(ix,:) = [];
        classlabel(ix,:)=[];
        W = [];
    else
        %ix = any(isnan([classlabel]),2);
        ix = any(isnan([D,classlabel]),2);
        D(ix,:)=[];
        classlabel(ix,:)=[];
        W(ix,:)=[];
        warning('support for weighting of samples is still experimental');
    end

    sz = size(D);
    if sz(1)~=length(classlabel),
        error('length of data and classlabel does not fit');
    end
    if ~isfield(MODE,'hyperparameter')
        MODE.hyperparameter = [];
    end

    if 0,
        ;
    elseif ~isempty(strfind(lower(MODE.TYPE),'/delet'))
        POS1 = find(MODE.TYPE=='/');
        [rix,cix] = row_col_deletion(D);
        if ~isempty(W), W=W(rix); end
        CC = train_sc(D(rix,cix),classlabel(rix,:),MODE.TYPE(1:POS1(1)-1),W);
        CC.G = sparse(cix, 1:length(cix), 1, size(D,2), length(cix));
        if isfield(CC,'weights')
            W = [CC.weights(1,:); CC.weights(2:end,:)];
            CC.weights = sparse(size(D,2)+1, size(W,2));
            CC.weights([1,cix+1],:) = W;
            CC.datatype = ['classifier:statistical:',lower(MODE.TYPE)];
        else
            CC.datatype = [CC.datatype,'/delet'];
        end

    elseif ~isempty(strfind(lower(MODE.TYPE),'nbpw'))
        error('NBPW not implemented yet')
        %%%% Naive Bayesian Parzen Window Classifier.
        [classlabel,CC.Labels] = CL1M(classlabel);
        for k = 1:length(CC.Labels),
            [d,CC.MEAN(k,:)] = center(D(classlabel==CC.Labels(k),:),1);
            [CC.VAR(k,:),CC.N(k,:)] = sumskipnan(d.^2,1);
            h2_opt = (4./(3*CC.N(k,:))).^(2/5).*CC.VAR(k,:);
            %%% TODO
        end


    elseif ~isempty(strfind(lower(MODE.TYPE),'nbc'))
        %%%% Naive Bayesian Classifier
        if ~isempty(strfind(lower(MODE.TYPE),'anbc'))
            %%%% Augmented Naive Bayesian classifier.
            [CC.V,L] = eig(covm(D,'M',W));
            D = D*CC.V;
        else
            CC.V = eye(size(D,2));
        end
        [classlabel,CC.Labels] = CL1M(classlabel);
        for k = 1:length(CC.Labels),
            ix = classlabel==CC.Labels(k);
            %% [d,CC.MEAN(k,:)] = center(D(ix,:),1);
            if ~isempty(W)
                [s,n] = sumskipnan(D(ix,:),1,W(ix));
                CC.MEAN(k,:) = s./n;
                d = D(ix,:) - CC.MEAN(repmat(k,sum(ix),1),:);
                [CC.VAR(k,:),CC.N(k,:)] = sumskipnan(d.^2,1,W(ix));
            else
                [s,n] = sumskipnan(D(ix,:),1);
                CC.MEAN(k,:) = s./n;
                d = D(ix,:) - CC.MEAN(repmat(k,sum(ix),1),:);
                [CC.VAR(k,:),CC.N(k,:)] = sumskipnan(d.^2,1);
            end
        end
        CC.VAR = CC.VAR./max(CC.N-1,0);
        CC.datatype = ['classifier:',lower(MODE.TYPE)];


    elseif ~isempty(strfind(lower(MODE.TYPE),'lpm'))
        if ~isempty(W)
            error('Error TRAIN_SC: Classifier (%s) does not support weighted samples.',MODE.TYPE);
        end
        % linear programming machine
        % CPLEX optimizer: ILOG solver, ilog cplex 6.5 reference manual http://www.ilog.com
        MODE.TYPE = 'LPM';
        if ~isfield(MODE.hyperparameter,'c_value')
            MODE.hyperparameter.c_value = 1;
        end
        [classlabel,CC.Labels] = CL1M(classlabel);

        M = length(CC.Labels);
        if M==2, M=1; end   % For a 2-class problem, only 1 Discriminant is needed
        for k = 1:M,
            %LPM = train_LPM(D,(classlabel==CC.Labels(k)),'C',MODE.hyperparameter.c_value);
            LPM = train_LPM(D',(classlabel'==CC.Labels(k)));
            CC.weights(:,k) = [-LPM.b; LPM.w(:)];
        end
        CC.hyperparameter.c_value = MODE.hyperparameter.c_value;
        CC.datatype = ['classifier:',lower(MODE.TYPE)];


    elseif ~isempty(strfind(lower(MODE.TYPE),'pla')),
        % Perceptron Learning Algorithm

        [rix,cix] = row_col_deletion(D);
        [CL101,CC.Labels] = cl101(classlabel);
        M = size(CL101,2);
        weights = sparse(length(cix)+1,M);

        %ix = randperm(size(D,1));      %% randomize samples ???
        if isempty(W)
            if isfield(MODE.hyperparameter,'alpha')
                alpha = MODE.hyperparameter.alpha;
            else
                alpha = 1;
            end
            for k = rix(:)',
                %e = ((classlabel(k)==(1:M))-.5) - sign([1, D(k,cix)] * weights)/2;
                e = CL101(k,:) - sign([1, D(k,cix)] * weights);
                weights = weights + alpha * [1,D(k,cix)]' * e ;
            end

        else %if ~isempty(W)
            if isfield(MODE.hyperparameter,'alpha')
                W = W*MODE.hyperparameter.alpha;
            end
            for k = rix(:)',
                %e = ((classlabel(k)==(1:M))-.5) - sign([1, D(k,cix)] * weights)/2;
                e = CL101(k,:) - sign([1, D(k,cix)] * weights);
                weights = weights + W(k) * [1,D(k,cix)]' * e ;
            end
        end
        CC.weights = sparse(size(D,2)+1,M);
        CC.weights([1,cix+1],:) = weights;
        CC.datatype = ['classifier:',lower(MODE.TYPE)];


    elseif ~isempty(strfind(lower(MODE.TYPE),'adaline')) || ~isempty(strfind(lower(MODE.TYPE),'lms')),
        % adaptive linear element, least mean squares, delta rule, Widrow-Hoff

        [rix,cix] = row_col_deletion(D);
        [CL101,CC.Labels] = cl101(classlabel);
        M = size(CL101,2);
        weights = sparse(length(cix)+1,M);

        %ix = randperm(size(D,1));      %% randomize samples ???
        if isempty(W)
            if isfield(MODE.hyperparameter,'alpha')
                alpha = MODE.hyperparameter.alpha;
            else
                alpha = 1;
            end
            for k = rix(:)',
                %e = (classlabel(k)==(1:M)) - [1, D(k,cix)] * weights;
                e = CL101(k,:) - sign([1, D(k,cix)] * weights);
                weights = weights + alpha * [1,D(k,cix)]' * e ;
            end

        else %if ~isempty(W)
            if isfield(MODE.hyperparameter,'alpha')
                W = W*MODE.hyperparameter.alpha;
            end
            for k = rix(:)',
                %e = (classlabel(k)==(1:M)) - [1, D(k,cix)] * weights;
                e = CL101(k,:) - sign([1, D(k,cix)] * weights);
                weights = weights + W(k) * [1,D(k,cix)]' * e ;
            end
        end
        CC.weights = sparse(size(D,2)+1,M);
        CC.weights([1,cix+1],:) = weights;
        CC.datatype = ['classifier:',lower(MODE.TYPE)];


    elseif ~isempty(strfind(lower(MODE.TYPE),'winnow'))
        % winnow algorithm
        if ~isempty(W)
            error('Classifier (%s) does not support weighted samples.',MODE.TYPE);
        end

        [rix,cix] = row_col_deletion(D);
        [CL101,CC.Labels] = cl101(classlabel);
        M = size(CL101,2);
        weights = ones(length(cix),M);
        theta = size(D,2)/2;

        for k = rix(:)',
            e = CL101(k,:) - sign(D(k,cix) * weights - theta);
            weights = weights.* 2.^(D(k,cix)' * e);
        end

        CC.weights = sparse(size(D,2)+1,M);
        CC.weights(cix+1,:) = weights;
        CC.datatype = ['classifier:',lower(MODE.TYPE)];

    elseif ~isempty(strfind(lower(MODE.TYPE),'pls')) || ~isempty(strfind(lower(MODE.TYPE),'reg'))
        % 4th version: support for weighted samples - work well with unequally distributed data:
        % regression analysis, can handle sparse data, too.

        if nargin<4,
            W = [];
        end
        [rix, cix] = row_col_deletion(D);
        wD = [ones(length(rix),1),D(rix,cix)];

        if ~isempty(W)
            %% wD = diag(W)*wD
            W = W(:);
            for k=1:size(wD,2)
                wD(:,k) = W(rix).*wD(:,k);
            end
        end
        [CL101, CC.Labels] = cl101(classlabel(rix,:));
        M = size(CL101,2);
        CC.weights = sparse(sz(2)+1,M);

        %[rix, cix] = row_col_deletion(wD);
        [q,r] = qr(wD,0);

        if isempty(W)
            CC.weights([1,cix+1],:) = r\(q'*CL101);
        else
            CC.weights([1,cix+1],:) = r\(q'*(W(rix,ones(1,M)).*CL101));
        end
        %for k = 1:M,
        %       CC.weights(cix,k) = r\(q'*(W.*CL101(rix,k)));
        %end
        CC.datatype = ['classifier:statistical:',lower(MODE.TYPE)];


    elseif ~isempty(strfind(MODE.TYPE,'WienerHopf'))
        % Q: equivalent to LDA
        % equivalent to Regression, except regression can not deal with NaN's
        [CL101,CC.Labels] = cl101(classlabel);
        M = size(CL101,2);
        CC.weights = sparse(size(D,2)+1,M);
        cc = covm(D,'E',W);
        %c1 = classlabel(~isnan(classlabel));
        %c2 = ones(sum(~isnan(classlabel)),M);
        %for k = 1:M,
        %       c2(:,k) = c1==CC.Labels(k);
        %end
        %CC.weights  = cc\covm([ones(size(c2,1),1),D(~isnan(classlabel),:)],2*real(c2)-1,'M',W);
        CC.weights = cc\covm([ones(size(D,1),1),D],CL101,'M',W);
        CC.datatype = ['classifier:statistical:',lower(MODE.TYPE)];


    elseif ~isempty(strfind(lower(MODE.TYPE),'/gsvd'))
        if ~isempty(W)
            error('Classifier (%s) does not support weighted samples.',MODE.TYPE);
        end
        % [2] Peg Howland and Haesun Park, 2004
        %       Generalizing Discriminant Analysis Using the Generalized Singular Value Decomposition
        %       IEEE Transactions on Pattern Analysis and Machine Intelligence, 26(8), 2004.
        %       dx.doi.org/10.1109/TPAMI.2004.46
        % [3] http://www-static.cc.gatech.edu/~kihwan23/face_recog_gsvd.htm

        [classlabel,CC.Labels] = CL1M(classlabel);
        [rix,cix] = row_col_deletion(D);

        Hw = zeros(length(rix)+length(CC.Labels), length(cix));
        Hb = [];
        m0 = mean(D(rix,cix));
        K = length(CC.Labels);
        N = zeros(1,K);
        for k = 1:K,
            ix = find(classlabel(rix)==CC.Labels(k));
            N(k) = length(ix);
            [Hw(ix,:), mu] = center(D(rix(ix),cix));
            %Hb(k,:) = sqrt(N(k))*(mu(k,:)-m0);
            Hw(length(rix)+k,:) = sqrt(N(k))*(mu-m0);  % Hb(k,:)
        end
        try
            [P,R,Q] = svd(Hw,'econ');
        catch   % needed because SVD(..,'econ') not supported in Matlab 6.x
            [P,R,Q] = svd(Hw,0);
        end
        t = rank(R);

        clear Hw Hb mu;
        %[size(D);size(P);size(Q);size(R)]
        R = R(1:t,1:t);
        %P = P(1:size(D,1),1:t);
        %Q = Q(1:t,:);
        [U,E,W] = svd(P(1:length(rix),1:t),0);
        %[size(U);size(E);size(W)]
        clear U E P;
        %[size(Q);size(R);size(W)]

        %G = Q(1:t,:)'*[R\W'];
        G = Q(:,1:t)*(R\W');   % this works as well and needs only 'econ'-SVD
        %G = G(:,1:t);  % not needed

        % do not use this, gives very bad results for Medline database
        %G = G(:,1:K); this seems to be a typo in [2] and [3].
        CC = train_sc(D(:,cix)*G,classlabel,MODE.TYPE(1:find(MODE.TYPE=='/')-1));
        CC.G = sparse(size(D,2),size(G,2));
        CC.G(cix,:) = G;
        if isfield(CC,'weights')
            CC.weights = sparse([CC.weights(1,:); CC.G*CC.weights(2:end,:)]);
            CC.datatype = ['classifier:statistical:', lower(MODE.TYPE)];
        else
            CC.datatype = [CC.datatype,'/gsvd'];
        end


    elseif ~isempty(strfind(lower(MODE.TYPE),'sparse'))
        if ~isempty(W)
            error('Classifier (%s) does not support weighted samples.',MODE.TYPE);
        end
        % [5] J.D. Tebbens and P. Schlesinger (2006),
        %       Improving Implementation of Linear Discriminant Analysis for the Small Sample Size Problem
        %       http://www.cs.cas.cz/mweb/download/publi/JdtSchl2006.pdf

        [classlabel,CC.Labels] = CL1M(classlabel);
        [rix,cix] = row_col_deletion(D);

        warning('sparse LDA is sensitive to linear transformations')
        M = length(CC.Labels);
        G = sparse([],[],[],length(rix),M,length(rix));
        for k = 1:M,
            G(classlabel(rix)==CC.Labels(k),k) = 1;
        end
        tol = 1e-10;

        G = train_lda_sparse(D(rix,cix),G,1,tol);
        CC.datatype = 'classifier:slda';
        POS1 = find(MODE.TYPE=='/');
        %G = v(:,1:size(G.trafo,2)).*G.trafo;
        %CC.weights = s * CC.weights(2:end,:) + sparse(1,1:M,CC.weights(1,:),sz(2)+1,M);

        CC = train_sc(D(rix,cix)*G.trafo,classlabel(rix),MODE.TYPE(1:POS1(1)-1));
        CC.G = sparse(size(D,2),size(G.trafo,2));
        CC.G(cix,:) = G.trafo;
        if isfield(CC,'weights')
            CC.weights = sparse([CC.weights(1,:); CC.G*CC.weights(2:end,:)]);
            CC.datatype = ['classifier:statistical:',lower(MODE.TYPE)];
        else
            CC.datatype = [CC.datatype,'/sparse'];
        end

    elseif ~isempty(strfind(lower(MODE.TYPE),'rbf'))
        if ~isempty(W)
            error('Classifier (%s) does not support weighted samples.',MODE.TYPE);
        end

        % Martin Hieden's RBF-SVM
        if exist('svmpredict_mex','file')==3,
            MODE.TYPE = 'SVM:LIB:RBF';
        else
            error('No SVM training algorithm available. Install LibSVM for Matlab.\n');
        end
        CC.options = '-t 2 -q';     % use RBF kernel, set C, set gamma
        if isfield(MODE.hyperparameter,'c_value')
            CC.options = sprintf('%s -c %g', CC.options, MODE.hyperparameter.c_value);  % set C
        end
        if isfield(MODE.hyperparameter,'gamma')
            CC.options = sprintf('%s -g %g', CC.options, MODE.hyperparameter.gamma);    % set gamma
        end

        % pre-whitening
        [D,r,m]=zscore(D,1);
        CC.prewhite = sparse(2:sz(2)+1,1:sz(2),r,sz(2)+1,sz(2),2*sz(2));
        CC.prewhite(1,:) = -m.*r;

        [classlabel,CC.Labels] = CL1M(classlabel);
        CC.model = svmtrain_mex(classlabel, D, CC.options);    % Call the training mex File
        CC.datatype = ['classifier:',lower(MODE.TYPE)];


    elseif ~isempty(strfind(lower(MODE.TYPE),'svm11'))
        if ~isempty(W)
            error('Classifier (%s) does not support weighted samples.',MODE.TYPE);
        end
        % 1-versus-1 scheme
        if ~isfield(MODE.hyperparameter,'c_value')
            MODE.hyperparameter.c_value = 1;
        end

        CC.options=sprintf('-c %g -t 0 -q',MODE.hyperparameter.c_value);  %use linear kernel, set C
        CC.hyperparameter.c_value = MODE.hyperparameter.c_value;

        % pre-whitening
        [D,r,m]=zscore(D,1);
        CC.prewhite = sparse(2:sz(2)+1,1:sz(2),r,sz(2)+1,sz(2),2*sz(2));
        CC.prewhite(1,:) = -m.*r;

        [classlabel,CC.Labels] = CL1M(classlabel);
        CC.model = svmtrain_mex(classlabel, D, CC.options);    % Call the training mex File

        FUN = 'SVM:LIB:1vs1';
        CC.datatype = ['classifier:',lower(FUN)];


    elseif ~isempty(strfind(lower(MODE.TYPE),'psvm'))
        if ~isempty(W)
            %%% error('Classifier (%s) does not support weighted samples.',MODE.TYPE);
            warning('Classifier (%s) in combination with weighted samples is not tested.',MODE.TYPE);
        end
        if ~isfield(MODE,'hyperparameter')
            nu = 1;
        elseif isfield(MODE.hyperparameter,'nu')
            nu = MODE.hyperparameter.nu;
        else
            nu = 1;
        end
        [m,n] = size(D);
        [CL101,CC.Labels] = cl101(classlabel);
        CC.weights = sparse(n+1,size(CL101,2));
        M = size(CL101,2);
        for k = 1:M,
            d = sparse(1:m,1:m,CL101(:,k));
            H = d * [ones(m,1),D];
            %%% r = sum(H,1)';
            r = sumskipnan(H,1,W)';
            %%% r = (speye(n+1)/nu + H' * H)\r; %solve (I/nu+H'*H)r=H'*e
            [HTH, nn] = covm(H,H,'M',W);
            r = (speye(n+1)/nu + HTH)\r; %solve (I/nu+H'*H)r=H'*e
            u = nu*(1-(H*r));
            %%% CC.weights(:,k) = u'*H;
            [c,nn] = covm(u,H,'M',W);
            CC.weights(:,k) = c';
        end
        CC.hyperparameter.nu = nu;
        CC.datatype = ['classifier:',lower(MODE.TYPE)];

    elseif ~isempty(strfind(lower(MODE.TYPE),'svm:lin4'))
        if ~isfield(MODE.hyperparameter,'c_value')
            MODE.hyperparameter.c_value = 1;
        end

        [classlabel,CC.Labels] = CL1M(classlabel);
        M = length(CC.Labels);
        CC.weights = sparse(size(D,2)+1,M);

        [rix,cix] = row_col_deletion(D);

        % pre-whitening
        [D,r,m]=zscore(D(rix,cix),1);
        sz2 = length(cix);
        s = sparse(2:sz2+1,1:sz2,r,sz2+1,sz2,2*sz2);
        s(1,:) = -m.*r;

        CC.options = sprintf('-s 4 -B 1 -c %f -q', MODE.hyperparameter.c_value);      % C-SVC, C=1, linear kernel, degree = 1,
        model = train(W, classlabel, sparse(D), CC.options);    % C-SVC, C=1, linear kernel, degree = 1,
        weights = model.w([end,1:end-1],:)';

        CC.weights([1,cix+1],:) = s * weights(2:end,:) + sparse(1,1:M,weights(1,:),sz2+1,M); % include pre-whitening transformation
        CC.weights([1,cix+1],:) = s * CC.weights(cix+1,:) + sparse(1,1:M,CC.weights(1,:),sz2+1,M); % include pre-whitening transformation
        CC.hyperparameter.c_value = MODE.hyperparameter.c_value;
        CC.datatype = ['classifier:',lower(MODE.TYPE)];


    elseif ~isempty(strfind(lower(MODE.TYPE),'svm'))

        if ~isfield(MODE.hyperparameter,'c_value')
            MODE.hyperparameter.c_value = 1;
        end
        if any(MODE.TYPE==':'),
            % nothing to be done
        elseif exist('train','file')==3,
            MODE.TYPE = 'SVM:LIN';     %% liblinear
        elseif exist('svmtrain_mex','file')==3,
            MODE.TYPE = 'SVM:LIB';
        elseif (exist('svmtrain','file')==3),
            MODE.TYPE = 'SVM:LIB';
            fprintf(1,'You need to rename %s to svmtrain_mex.mex !! \n Press any key to continue !!!\n',which('svmtrain.mex'));
        elseif exist('svmtrain','file')==2,
            MODE.TYPE = 'SVM:bioinfo';
        elseif exist('mexSVMTrain','file')==3,
            MODE.TYPE = 'SVM:OSU';
        elseif exist('svcm_train','file')==2,
            MODE.TYPE = 'SVM:LOO';
        elseif exist('svmclass','file')==2,
            MODE.TYPE = 'SVM:KM';
        elseif exist('svc','file')==2,
            MODE.TYPE = 'SVM:Gunn';
        else
            error('No SVM training algorithm available. Install OSU-SVM, or LOO-SVM, or libSVM for Matlab.\n');
        end

        %%CC = train_svm(D,classlabel,MODE);
        [CL101,CC.Labels] = cl101(classlabel);
        M = size(CL101,2);
        [rix,cix] = row_col_deletion(D);
        CC.weights = sparse(sz(2)+1, M);

        % pre-whitening
        [D,r,m]=zscore(D(rix,cix),1);
        sz2 = length(cix);
        s = sparse(2:sz2+1,1:sz2,r,sz2+1,sz2,2*sz2);
        s(1,:) = -m.*r;

        for k = 1:M,
            cl = CL101(rix,k);
            if strncmp(MODE.TYPE, 'SVM:LIN',7);
                if isfield(MODE,'options')
                    CC.options = MODE.options;
                else
                    t = 0;
                    if length(MODE.TYPE)>7, t=MODE.TYPE(8)-'0'; end
                    if (t<0 || t>6), t=0; end
                    CC.options = sprintf('-s %i -B 1 -c %f -q',t, MODE.hyperparameter.c_value);      % C-SVC, C=1, linear kernel, degree = 1,
                end
                model = train(W, cl, sparse(D), CC.options);    % C-SVC, C=1, linear kernel, degree = 1,
                w = -model.w';
                Bias = -model.bias;
                w = -model.w(:,1:end-1)';
                Bias = -model.w(:,end)';

            elseif strcmp(MODE.TYPE, 'SVM:LIB');    %% tested with libsvm-mat-2.9-1
                if isfield(MODE,'options')
                    CC.options = MODE.options;
                else
                    CC.options = sprintf('-s 0 -c %f -t 0 -d 1 -q', MODE.hyperparameter.c_value);      % C-SVC, C=1, linear kernel, degree = 1,
                end
                model = svmtrain_mex(cl, D, CC.options);    % C-SVC, C=1, linear kernel, degree = 1,
                w = cl(1) * model.SVs' * model.sv_coef;  %Calculate decision hyperplane weight vector
                % ensure correct sign of weight vector and Bias according to class label
                Bias = model.rho * cl(1);

            elseif strcmp(MODE.TYPE, 'SVM:bioinfo');
                % SVM classifier from bioinformatics toolbox.
                % Settings suggested by Ian Daly, 2011-06-06
                options = optimset('Display','iter','maxiter',20000, 'largescale','off');
                CC.SVMstruct = svmtrain(D, cl, 'AUTOSCALE', 0, 'quadprog_opts', options, 'Method', 'LS', 'kernel_function', 'polynomial');
                Bias = -CC.SVMstruct.Bias;
                w = -CC.SVMstruct.Alpha'*CC.SVMstruct.SupportVectors;

            elseif strcmp(MODE.TYPE, 'SVM:OSU');
                [AlphaY, SVs, Bias] = mexSVMTrain(D', cl', [0 1 1 1 MODE.hyperparameter.c_value]);    % Linear Kernel, C=1; degree=1, c-SVM
                w = -SVs * AlphaY'*cl(1);  %Calculate decision hyperplane weight vector
                % ensure correct sign of weight vector and Bias according to class label
                Bias = -Bias * cl(1);

            elseif strcmp(MODE.TYPE, 'SVM:LOO');
                [a, Bias, g, inds] = svcm_train(D, cl, MODE.hyperparameter.c_value); % C = 1;
                w = D(inds,:)' * (a(inds).*cl(inds)) ;

            elseif strcmp(MODE.TYPE, 'SVM:Gunn');
                [nsv, alpha, Bias,svi] = svc(D, cl, 1, MODE.hyperparameter.c_value); % linear kernel, C = 1;
                w = D(svi,:)' * alpha(svi) * cl(1);
                Bias = mean(D*w);

            elseif strcmp(MODE.TYPE, 'SVM:KM');
                [xsup,w1,Bias,inds] = svmclass(D, cl, MODE.hyperparameter.c_value, 1, 'poly', 1); % C = 1;
                w = -D(inds,:)' * w1;

            else
                fprintf(2,'Error TRAIN_SVM: no SVM training algorithm available\n');
                return;
            end

            CC.weights(1,k) = -Bias;
            CC.weights(cix+1,k) = w;
        end
        CC.weights([1,cix+1],:) = s * CC.weights(cix+1,:) + sparse(1,1:M,CC.weights(1,:),sz2+1,M); % include pre-whitening transformation
        CC.hyperparameter.c_value = MODE.hyperparameter.c_value;
        CC.datatype = ['classifier:',lower(MODE.TYPE)];


    elseif ~isempty(strfind(lower(MODE.TYPE),'csp'))
        CC.datatype = ['classifier:',lower(MODE.TYPE)];
        [classlabel,CC.Labels] = CL1M(classlabel);
        CC.MD = repmat(NaN,[sz(2)+[1,1],length(CC.Labels)]);
        CC.NN = CC.MD;
        for k = 1:length(CC.Labels),
            %% [CC.MD(k,:,:),CC.NN(k,:,:)] = covm(D(classlabel==CC.Labels(k),:),'E');
            ix = classlabel==CC.Labels(k);
            if isempty(W)
                [CC.MD(:,:,k),CC.NN(:,:,k)] = covm(D(ix,:), 'E');
            else
                [CC.MD(:,:,k),CC.NN(:,:,k)] = covm(D(ix,:), 'E', W(ix));
            end
        end
        ECM = CC.MD./CC.NN;
        W = csp(ECM,'CSP3');
        %%% ### This is a hack ###
        CC.FiltA = 50;
        CC.FiltB = ones(CC.FiltA,1);
        d = filtfilt(CC.FiltB,CC.FiltA,(D*W).^2);
        CC.csp_w = W;
        CC.CSP = train_sc(log(d),classlabel);


    else          % Linear and Quadratic statistical classifiers
        CC.datatype = ['classifier:statistical:',lower(MODE.TYPE)];
        [classlabel,CC.Labels] = CL1M(classlabel);
        CC.MD = repmat(NaN,[sz(2)+[1,1],length(CC.Labels)]);
        CC.NN = CC.MD;
        for k = 1:length(CC.Labels),
            ix = classlabel==CC.Labels(k);
            if isempty(W)
                [CC.MD(:,:,k),CC.NN(:,:,k)] = covm(D(ix,:), 'E');
            else
                [CC.MD(:,:,k),CC.NN(:,:,k)] = covm(D(ix,:), 'E', W(ix));
            end
        end

        ECM = CC.MD./CC.NN;
        NC  = size(CC.MD);
        if strncmpi(MODE.TYPE,'LD',2) || strncmpi(MODE.TYPE,'FDA',3) || strncmpi(MODE.TYPE,'FLDA',4),

            %if NC(1)==2, NC(1)=1; end   % linear two class problem needs only one discriminant
            CC.weights = repmat(NaN,NC(2),NC(3));     % memory allocation
            type = MODE.TYPE(3)-'0';

            ECM0 = squeeze(sum(ECM,3));  %decompose ECM
            for k = 1:NC(3);
                ix = [1:k-1,k+1:NC(3)];
                dM = CC.MD(:,1,k)./CC.NN(:,1,k) - sum(CC.MD(:,1,ix),3)./sum(CC.NN(:,1,ix),3);
                switch (type)
                case 2          % LD2
                    ecm0 = (sum(ECM(:,:,ix),3)/(NC(3)-1) + ECM(:,:,k));
                case 4          % LD4
                    ecm0 = 2*(sum(ECM(:,:,ix),3) + ECM(:,:,k))/NC(3);
                    % ecm0 = sum(CC.MD,3)./sum(CC.NN,3);
                case 5          % LD5
                    ecm0 = ECM(:,:,k);
                case 6          % LD6
                    ecm0 = sum(CC.MD(:,:,ix),3)./sum(CC.NN(:,:,ix),3);
                otherwise       % LD3, LDA, FDA
                    ecm0 = ECM0;
                end
                if isfield(MODE.hyperparameter,'gamma')
                    ecm0 = ecm0 + mean(diag(ecm0))*eye(size(ecm0))*MODE.hyperparameter.gamma;
                end

                CC.weights(:,k) = ecm0\dM;

            end
            %CC.weights = sparse(CC.weights);

        elseif strcmpi(MODE.TYPE,'RDA');
            if isfield(MODE,'hyperparameter')
                CC.hyperparameter = MODE.hyperparameter;
            end
            % default values
            if ~isfield(CC.hyperparameter,'gamma')
                CC.hyperparameter.gamma = 0;
            end
            if ~isfield(CC.hyperparameter,'lambda')
                CC.hyperparameter.lambda = 1;
            end
        else
            ECM0 = sum(ECM,3);
            nn = ECM0(1,1,1);       % number of samples in training set for class k
            XC = squeeze(ECM0(:,:,1))/nn;           % normalize correlation matrix
            M  = XC(1,2:NC(2));     % mean
            S  = XC(2:NC(2),2:NC(2)) - M'*M;% covariance matrix

            try
                [v,d]=eig(S);
                U0 = v(diag(d)==0,:);
                CC.iS2 = U0*U0';
            end

            %M  = M/nn; S=S/(nn-1);
            ICOV0 = inv(S);
            CC.iS0 = ICOV0;
            % ICOV1 = zeros(size(S));
            for k = 1:NC(3),
                %[M,sd,S,xc,N] = decovm(ECM{k});  %decompose ECM
                %c  = size(ECM,2);
                nn = ECM(1,1,k);% number of samples in training set for class k
                XC = squeeze(ECM(:,:,k))/nn;% normalize correlation matrix
                M  = XC(1,2:NC(2));% mean
                S  = XC(2:NC(2),2:NC(2)) - M'*M;% covariance matrix
                %M  = M/nn; S=S/(nn-1);

                %ICOV(1) = ICOV(1) + (XC(2:NC(2),2:NC(2)) - )/nn

                CC.M{k}   = M;
                CC.IR{k}  = [-M;eye(NC(2)-1)]*inv(S)*[-M',eye(NC(2)-1)];  % inverse correlation matrix extended by mean
                CC.IR0{k} = [-M;eye(NC(2)-1)]*ICOV0*[-M',eye(NC(2)-1)];  % inverse correlation matrix extended by mean
                d = NC(2)-1;
                if exist('OCTAVE_VERSION','builtin')
                    S = full(S);
                end
                CC.logSF(k)  = log(nn) - d/2*log(2*pi) - log(det(S))/2;
                CC.logSF2(k) = -2*log(nn/sum(ECM(:,1,1)));
                CC.logSF3(k) = d*log(2*pi) + log(det(S));
                CC.logSF4(k) = log(det(S)) + 2*log(nn);
                CC.logSF5(k) = log(det(S));
                CC.logSF6(k) = log(det(S)) - 2*log(nn/sum(ECM(:,1,1)));
                CC.logSF7(k) = log(det(S)) + d*log(2*pi) - 2*log(nn/sum(ECM(:,1,1)));
                CC.logSF8(k) = sum(log(svd(S))) + log(nn) - log(sum(ECM(:,1,1)));
                CC.SF(k) = nn/sqrt((2*pi)^d * det(S));
                %CC.datatype='LLBC';
            end
        end
    end
end

function [CL101,Labels] = cl101(classlabel)
    %% convert classlabels to {-1,1} encoding

    if (all(classlabel>=0) && all(classlabel==fix(classlabel)) && (size(classlabel,2)==1))
        M = max(classlabel);
        if M==2,
            CL101 = (classlabel==2)-(classlabel==1);
        else
            CL101 = zeros(size(classlabel,1),M);
            for k=1:M,
                %% One-versus-Rest scheme
                CL101(:,k) = 2*real(classlabel==k) - 1;
            end
        end
        CL101(isnan(classlabel),:) = NaN; %% or zero ???

    elseif all((classlabel==1) | (classlabel==-1) | (classlabel==0) )
        CL101 = classlabel;
        M = size(CL101,2);
    else
        classlabel,
        error('format of classlabel unsupported');
    end
    Labels = 1:M;
    return;
end


function [cl1m, Labels] = CL1M(classlabel)
    %% convert classlabels to 1..M encoding
    if (all(classlabel>=0) && all(classlabel==fix(classlabel)) && (size(classlabel,2)==1))
        cl1m = classlabel;

    elseif all((classlabel==1) | (classlabel==-1) | (classlabel==0) )
        CL101 = classlabel;
        M = size(classlabel,2);
        if any(sum(classlabel==1,2)>1)
            warning('invalid format of classlabel - at most one category may have +1');
        end
        if (M==1),
            cl1m = (classlabel==-1) + 2*(classlabel==+1);
        else
            [tmp, cl1m] = max(classlabel,[],2);
            if any(tmp ~= 1)
                warning('some class might not be properly represented - you might want to add another column to classlabel = [max(classlabel,[],2)<1,classlabel]');
            end
            cl1m(tmp<1)= 0; %% or NaN ???
        end
    else
        classlabel
        error('format of classlabel unsupported');
    end
    Labels = 1:max(cl1m);
    return;
end
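
%% Self-test sketch (not part of the original file). It uses Octave's "%!"
%% test-block convention, so it can be run with "test train_sc" in Octave,
%% while MATLAB treats these lines as plain comments. It assumes that the
%% NaN-toolbox (covm etc.) is on the path. The data are made up, and only
%% properties that follow directly from the LDA branch above are checked:
%% the datatype string, the label vector, and the size of the weight matrix
%% (number of features + 1 rows, one column per class).
%!test
%! D  = [0 0; 1 0; 0 1; 1 1; 4 4; 5 4; 4 5; 5 5];
%! cl = [1;1;1;1;2;2;2;2];
%! CC = train_sc(D, cl, 'LDA');
%! assert(CC.datatype, 'classifier:statistical:lda');
%! assert(CC.Labels, 1:2);
%! assert(size(CC.weights), [3, 2]);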