function [CC,Q,tsd,md]=findclassifier1(D,TRIG,cl,T,t0,SWITCH)
% FINDCLASSIFIER1 (obsolete, replaced by FINDCLASSIFIER)
%
% [CC,Q,TSD,MD]=findclassifier1(D,TRIG,Class,class_times,t_ref,SWITCH);
%
% D            data, each row is one time point
% TRIG         trigger time points, one per trial
% Class        class label of each trial (integer values; NaN marks an
%              unlabeled trial, which is classified in CC.OUT)
% class_times  classification times relative to the trigger; each row is
%              one candidate segment (combinations of times must be in one row)
% t_ref        (optional) mask with one entry per row of class_times; rows
%              where t_ref is false/zero are not considered when the best
%              classification time CC.TI is selected. Default: all rows.
% SWITCH       (optional, default 0) if bit 1 is set, CC.LDA/MDA/GRB.ERR00
%              and the BCI3EVAL results CC.LDA/MDA/MD2.TSD are computed.
%
% CC   contains LDA and MD classifiers and their (cross-validated) evaluation
% Q    QCMAHAL quality over the statistics of all segments and classes
%      (the quality of each row of class_times is returned in CC.QC)
% TSD  LDA classifier output (LDBC) for every sample; only computed for
%      two classes, otherwise empty
% MD   MD classifier output (MDBC) for every sample
%      TSD and MD refer to D after a column of ones has been prepended (if
%      the first column is not already all ones) and after NaN padding has
%      been added at the beginning and/or end, if needed.
%
% Example:
% [CC,Q,TSD,MD]=findclassifier1(AR,find(trig>0.5)-257,~mod(1:80,2),reshape(1:14*128,16,14*8)');
%
%
% Reference(s):
% [1] Schlögl A., Neuper C., Pfurtscheller G.
%     Estimating the mutual information of an EEG-based Brain-Computer-Interface
%     Biomedizinische Technik 47(1-2): 3-8, 2002.
% [2] A. Schlögl, C. Keinrath, R. Scherer, G. Pfurtscheller,
%     Information transfer of an EEG-based Brain-Computer Interface.
%     Proceedings of the 1st International IEEE EMBS Conference on Neural Engineering, Capri, Italy, Mar 20-22, 2003


%    Copyright (C) 1999-2004 by Alois Schloegl <alois.schloegl@gmail.com>
%    $Id$


%    This program is free software; you can redistribute it and/or
%    modify it under the terms of the GNU General Public License
%    as published by the Free Software Foundation; either version 2
%    of the License, or (at your option) any later version.
%
%    This program is distributed in the hope that it will be useful,
%    but WITHOUT ANY WARRANTY; without even the implied warranty of
%    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
%    GNU General Public License for more details.
%
%    You should have received a copy of the GNU General Public License
%    along with this program; if not, write to the Free Software
%    Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.


warning('this function is obsolete and replaced by FINDCLASSIFIER');


tsd=[];md=[];

if nargin<6,
    SWITCH=0;
end;
if nargin>4,
    if isempty(t0),
        t0=logical(ones(size(T,1),1));
    end;
end;
tmp=cl;tmp(isnan(tmp))=0;
if any(rem(tmp,1) & ~isnan(cl)),
    fprintf(2,'Error %s: class information is not integer\n',mfilename);
    return;
end;
if length(TRIG)~=length(cl);
    fprintf(2,'number of Triggers do not match class information\n');
end;

CL = unique(cl(~isnan(cl)));
CL = sort(CL);
TRIG = TRIG(:);
% prepend a column of ones unless D already has one
if ~all(D(:,1)==1)
    D =[ones(size(D,1),1),D];
end;

% add sufficient NaNs at the beginning and the end, so that every index
% TRIG+T addresses an existing row of D
tmp = min(TRIG)+min(min(T))-1;
if tmp<0,
    TRIG = TRIG - tmp;
    D = [repmat(nan,[-tmp,size(D,2)]);D];
end;
tmp = max(TRIG)+max(max(T))-size(D,1);
if tmp>0,
    D = [D;repmat(nan,[tmp,size(D,2)])];
end;

% estimate classification result for all time segments in T - without crossvalidation
CMX = zeros([size(T,1),length(CL)*[1,1]]);
for k = 1:size(T,1),
    cmx = zeros(length(CL));
    for l = 1:length(CL),
        t = perm(TRIG(cl==CL(l)),T(k,:));
        [C{k,l},tmp] = covm(D(t(:),:),'M');
    end;
    [CC.QC(k),d{k}] = qcmahal({C{k,:}});
    lnQ(k) = mean(log(d{k}(~eye(length(d{k})))));
    for l = 1:length(CL),
        t = perm(TRIG(cl==CL(l)),T(k,:));
        [tmp] = mdbc({C{k,:}},D(t(:),:));
        [tmp,ix] = min(tmp,[],2);
        % ix is the position (1..length(CL)) of the nearest class;
        % samples with NaN distance are not classified and not counted.
        ix(isnan(tmp)) = NaN;
        % count by class position (not by class label), so that labels
        % such as 0 or non-contiguous labels can be used as row indices.
        for m = 1:length(CL),
            cmx(m,l) = sum(ix==m);
        end;
    end;
    CMX(k,:,:) = cmx;
    CC.KAPPA(k) = kappa(cmx);
    CC.ACC(k) = sum(diag(cmx))/sum(cmx(:));
end;
% identify best classification time
if nargin>4,
    tmp = CC.QC;
    tmp(~t0) = 0;
    [maxQ,CC.TI] = max(tmp);
else
    [maxQ,CC.TI] = max(CC.QC);
end;

% build MD classifier
CC.MD = {C{CC.TI,:}};
CC.IR = mdbc({C{CC.TI,:}});
CC.D = d{CC.TI};
CC.Q = CC.QC(CC.TI);
CC.CMX = squeeze(CMX(CC.TI,:,:));

m1=decovm(CC.MD{1});
m2=decovm(CC.MD{2});
tmp=mdbc(CC.MD,[1,m1;1,m2]);
CC.scale=[1,1]*max(abs(tmp(:)));  % element 1

[maxQ,CC.lnTI] = max(lnQ);
CC.DistMXln = d{CC.lnTI};
CC.MDln = {C{CC.lnTI,:}};

% alternative classifier using two different time segments.
[Q,d] = qcmahal({C{:}}');
CC.T2.D = d;
[ix,iy] = find(d==max(d(:)));
ix=mod(ix(1)-1,size(C,1))+1;
iy=mod(iy(1)-1,size(C,1))+1;
CC.T2.TI = [ix(1),iy(1)];
CC.T2.MD = {C{ix(1),1},C{iy(1),2}};

% build LDA classifier
if length(CL)==2,
    C0 = zeros(size(C{CC.TI,1}));
    for l=1:length(CL);
        [M{l},sd,COV,xc,N,R2] = decovm(C{CC.TI,l});
        C0 = C0 + C{CC.TI,l};
    end;
    [M0,sd,COV0,xc,N,R2] = decovm(C0);
    w     = COV0\(M{1}'-M{2}');
    w0    = M0*w;
    CC.lda = [w0; -w];

    % LDA output for all samples
    tsd  = ldbc(CC.MD,D);
end;
md = mdbc(CC.MD,D);

% bias correction not used anymore.
CC.BIAS.LDA=0;%mean(tsd);
CC.BIAS.MDA=0;%mean(diff(-md,[],2));
CC.BIAS.GRB=0;%mean(diff(-exp(-md/2),[],2));

%% cross-validation with jackknife (leave one trial out)
nc  = max(max(T))-min(min(T))+1;    % number of samples in one trial window

JKD   = repmat(nan,[nc,length(CL),length(TRIG)]);
JKD1  = repmat(nan,[nc,length(TRIG)]);
JKD2  = repmat(nan,[nc,length(TRIG)]);
JKLD  = repmat(nan,[nc,length(TRIG)]);
% position (1..length(CL)) of the nearest class; stays NaN for unlabeled trials
MDIX  = repmat(nan,[nc,length(TRIG)]);
valid = find(~isnan(cl(:)'));       % indices of labeled trials
for l = valid,
    c = find(cl(l)==CL);
    t = TRIG(l)+T(CC.TI,:);
    [tmp,tmpn] = covm(D(t(:),:),'M');

    % remove the statistics of trial l from its own class
    cc    = CC.MD;
    cc{c} = CC.MD{c}-tmp;

    % evaluate on the whole trial window min(T):max(T)
    t = TRIG(l)+(min(min(T)):max(max(T)));

    [d,ix] = llbc(cc,D(t,:));
    if length(CL)==2,
        % JKD3/JKD4 are only used by the disabled MLL section at the end
        JKD3(:,l)=d(:,1);
        JKD4(:,l)=d(:,2);
    end;

    d = mdbc(cc,D(t,:));
    JKD(:,:,l) = d;
    [tmp,MDIX(:,l)] = min(d,[],2);

    if length(CL)==2,
        JKD1(:,l) = d(:,1);
        JKD2(:,l) = d(:,2);

        LDA(:,l) = ldbc(cc);
        JKLD(:,l) = D(t,:)*LDA(:,l);
    end;
end;
if length(CL)==2,
    % LDA is only defined for two classes; use only the columns of labeled
    % trials (columns of skipped trials are zero-filled by MATLAB).
    [CC.ldaC0,NN] = covm(LDA(:,valid)','D0');
    %CC.ldaC0=CC.ldaC0./NN*min(0,sum(~isnan(CL))-1);
    % since NN==min(0,sum(~isnan(CL))-1), no need to rescale
end;

% Concordance matrix with cross-validation
% columns (k-1)*length(CL)+(1:length(CL)) of CC.mmx count, for each sample
% of the trial window, the trials of class CL(k) assigned to CL(1..end).
CC.mmx= zeros([size(MDIX,1),length(CL)^2]);
CC.I0 = zeros([size(MDIX,1),length(CL)]);
CC.I  = zeros([size(MDIX,1),1]);
tmp   = zeros([size(MDIX,1),length(CL)]);
for k = 1:length(CL),
    % trials of class CL(k) versus all other *labeled* trials
    ik = (cl==CL(k));
    io = (cl~=CL(k)) & ~isnan(cl);
    jkd = reshape(JKD(:,k,:),[nc,length(TRIG)]);
    o = bci3eval(jkd(:,io),jkd(:,ik),2);

    CC.TSD{k} = o;
    CC.I0(:,k) = log2(2*var(jkd(:,~isnan(cl)),[],2)./(var(jkd(:,ik),[],2) + var(jkd(:,io),[],2)))/2;

    [sum0,n0,ssq0]=sumskipnan(jkd(:,ik),2);
    [sum1,n1,ssq1]=sumskipnan(jkd(:,io),2);
    s0 = (ssq0-sum0.*sum0./n0)./(n0-1);
    s1 = (ssq1-sum1.*sum1./n1)./(n1-1);
    s  = (ssq0+ssq1-(sum0+sum1).*(sum0+sum1)./(n0+n1))./(n0+n1-1);
    SNR = 2*s./(s0+s1); % this is SNR+1
    CC.I1(:,k) = log2(SNR)/2;

    % MDIX holds class positions, so compare with l (not with CL(l))
    for l = 1:length(CL),
        tmp(:,l) = sum(MDIX(:,ik)==l,2);
    end;
    CC.mmx(:,(1-length(CL):0)+k*length(CL)) = tmp;
    CC.acc(:,k) = tmp(:,k)./sum(tmp,2);
end;
% row r of CC.mmx corresponds to offset min(T(:))+r-1
CC.CMX00 = reshape(sum(CC.mmx(1-min(T(:))+T(CC.TI,:),:),1),[1,1]*length(CL))/size(T,2);
CC.I = sum(CC.I0,2);
CC.ACC00 = sum(CC.mmx(:,1:length(CL)+1:end),2)/sum(~isnan(cl));
CC.KAP00 = zeros(size(MDIX,1),1);
for k = 1:size(MDIX,1),
    CC.KAP00(k) = kappa(reshape(CC.mmx(k,:),[1,1]*length(CL)));
end;

if length(CL) > 2,
    return;
end;


if bitand(SWITCH,1),
    CC.LDA.ERR00 = (mean(sign(JKLD),2)+1)/2;
    CC.MDA.ERR00 = (mean(sign(JKD1-JKD2),2)+1)/2;
    CC.GRB.ERR00 = (mean(sign(exp(-JKD2/2)-exp(-JKD1/2)),2)+1)/2;
end;


d=JKLD;
tmp1 = d(1-min(T(:))+T(CC.TI,:),cl==CL(1));
[sum0,n0,ssq0] = sumskipnan(tmp1(:));
tmp2 = d(1-min(T(:))+T(CC.TI,:),cl==CL(2));
[sum1,n1,ssq1] = sumskipnan(tmp2(:));
CC.LDA.AUC = auc(tmp1,tmp2);
CC.LDA.ERR(1,1) = mean(sign([tmp1(:)]))/2+1/2;
CC.LDA.ERR(1,2) = mean(sign([tmp2(:)]))/2+1/2;
CC.LDA.ERR(2,1) = mean(sign([mean(tmp1,1)']))/2+1/2;
CC.LDA.ERR(2,2) = mean(sign([mean(tmp2,1)']))/2+1/2;
s0 = (ssq0-sum0.*sum0./n0)./(n0-1);
s1 = (ssq1-sum1.*sum1./n1)./(n1-1);
s  = (ssq0+ssq1-(sum0+sum1).*(sum0+sum1)./(n0+n1))./(n0+n1-1);
SNR = 2*s./(s0+s1); % this is SNR+1
CC.LDA.I = log2(SNR)/2;
CC.LDA.SNR = SNR - 1;
if 0,
    clear tmp1 tmp2;
    tmp1 = stat2(d(:,cl==CL(1)),2);
    tmp2 = stat2(d(:,cl==CL(2)),2);
    CC.LDA.TSD=stat2res(tmp1,tmp2);
    CC.LDA.TSD.ERR=1/2-mean(sign([-d(:,cl==CL(1)),d(:,cl==CL(2))]),2)/2;
elseif bitand(SWITCH,1),
    CC.LDA.TSD=bci3eval(d(:,cl==CL(1)),d(:,cl==CL(2)),2);
end;

d = JKD1 - JKD2;
tmp1 = d(1-min(T(:))+T(CC.TI,:),cl==CL(1));
[sum0,n0,ssq0] = sumskipnan(tmp1(:));
tmp2 = d(1-min(T(:))+T(CC.TI,:),cl==CL(2));
[sum1,n1,ssq1] = sumskipnan(tmp2(:));
CC.MDA.AUC = auc(tmp1,tmp2);
CC.MDA.ERR(1,1) = mean(sign([tmp1(:)]))/2+1/2;
CC.MDA.ERR(1,2) = mean(sign([tmp2(:)]))/2+1/2;
CC.MDA.ERR(2,1) = mean(sign([mean(tmp1,1)']))/2+1/2;
CC.MDA.ERR(2,2) = mean(sign([mean(tmp2,1)']))/2+1/2;
s0 = (ssq0-sum0.*sum0./n0)./(n0-1);
s1 = (ssq1-sum1.*sum1./n1)./(n1-1);
s  = (ssq0+ssq1-(sum0+sum1).*(sum0+sum1)./(n0+n1))./(n0+n1-1);
SNR = 2*s./(s0+s1); % this is SNR+1
CC.MDA.I = log2(SNR)/2;
CC.MDA.SNR = SNR - 1;
if 0,
    clear tmp1 tmp2;
    tmp1 = stat2(d(:,cl==CL(1)),2);
    tmp2 = stat2(d(:,cl==CL(2)),2);
    CC.MDA.TSD=stat2res(tmp1,tmp2);
    CC.MDA.TSD.ERR=1/2-mean(sign([-d(:,cl==CL(1)),d(:,cl==CL(2))]),2)/2;
elseif bitand(SWITCH,1),
    CC.MDA.TSD=bci3eval(d(:,cl==CL(1)),d(:,cl==CL(2)),2);
end;

d = sqrt(JKD1) - sqrt(JKD2);
tmp1 = d(1-min(T(:))+T(CC.TI,:),cl==CL(1));
[sum0,n0,ssq0] = sumskipnan(tmp1(:));
tmp2 = d(1-min(T(:))+T(CC.TI,:),cl==CL(2));
[sum1,n1,ssq1] = sumskipnan(tmp2(:));
CC.MD2.AUC = auc(tmp1,tmp2);
CC.MD2.ERR(1,1) = mean(sign([tmp1(:)]))/2+1/2;
CC.MD2.ERR(1,2) = mean(sign([tmp2(:)]))/2+1/2;
CC.MD2.ERR(2,1) = mean(sign([mean(tmp1,1)']))/2+1/2;
CC.MD2.ERR(2,2) = mean(sign([mean(tmp2,1)']))/2+1/2;
s0 = (ssq0-sum0.*sum0./n0)./(n0-1);
s1 = (ssq1-sum1.*sum1./n1)./(n1-1);
s  = (ssq0+ssq1-(sum0+sum1).*(sum0+sum1)./(n0+n1))./(n0+n1-1);
SNR = 2*s./(s0+s1); % this is SNR+1
CC.MD2.I = log2(SNR)/2;
CC.MD2.SNR = SNR - 1;
if 0,
    clear tmp1 tmp2;
    tmp1 = stat2(d(:,cl==CL(1)),2);
    tmp2 = stat2(d(:,cl==CL(2)),2);
    CC.MD2.TSD=stat2res(tmp1,tmp2);
    CC.MD2.TSD.ERR=1/2-mean(sign([-d(:,cl==CL(1)),d(:,cl==CL(2))]),2)/2;
elseif bitand(SWITCH,1),
    CC.MD2.TSD=bci3eval(d(:,cl==CL(1)),d(:,cl==CL(2)),2);
end;

%%% classify the unlabeled (NaN) trials with the selected classifiers
if any(isnan(cl)),
    t = perm(TRIG(isnan(cl)),T(CC.TI,:));
    t = t(t<=size(D,1));
    tmp= rs(D(t(:),:),size(T,2),1);
    [CC.OUT.LDA] = ldbc(CC.MD,tmp);
    CC.OUT.LDAcl = CL((CC.OUT.LDA>0)+1);
    [CC.OUT.MDA] = mdbc(CC.MD,tmp);
    [tmp,ix] = min(CC.OUT.MDA,[],2);
    tmp = isnan(tmp);
    ix(tmp) = NaN; % invalid output, not classified
    ix(~tmp) = CL(ix(~tmp));
    CC.OUT.MDAcl = ix;
end;

return;

% NOTE: everything below is unreachable (it follows RETURN); the MLL and
% GRB evaluations are kept disabled for reference only.

d = JKD3 - JKD4;
tmp1 = d(1-min(T(:))+T(CC.TI,:),cl==CL(1));
[sum0,n0,ssq0] = sumskipnan(tmp1(:));
tmp2 = d(1-min(T(:))+T(CC.TI,:),cl==CL(2));
[sum1,n1,ssq1] = sumskipnan(tmp2(:));
CC.MLL.AUC = auc(tmp1,tmp2);
CC.MLL.ERR(1,1) = mean(sign([tmp1(:)]))/2+1/2;
CC.MLL.ERR(1,2) = mean(sign([tmp2(:)]))/2+1/2;
CC.MLL.ERR(2,1) = mean(sign([mean(tmp1,1)']))/2+1/2;
CC.MLL.ERR(2,2) = mean(sign([mean(tmp2,1)']))/2+1/2;
s0 = (ssq0-sum0.*sum0./n0)./(n0-1);
s1 = (ssq1-sum1.*sum1./n1)./(n1-1);
s  = (ssq0+ssq1-(sum0+sum1).*(sum0+sum1)./(n0+n1))./(n0+n1-1);
SNR = 2*s./(s0+s1); % this is SNR+1
CC.MLL.I = log2(SNR)/2;
CC.MLL.SNR = SNR - 1;
if 0,
    clear tmp1 tmp2;
    tmp1 = stat2(d(:,cl==CL(1)),2);
    tmp2 = stat2(d(:,cl==CL(2)),2);
    CC.MLL.TSD=stat2res(tmp1,tmp2);
    CC.MLL.TSD.ERR=mean(sign([-d(:,cl==CL(1)),d(:,cl==CL(2))]),2)/2+1/2;
elseif bitand(SWITCH,1),
    CC.MLL.TSD=bci3eval(d(:,cl==CL(1)),d(:,cl==CL(2)),2);
end;

d = exp(-JKD1/2)-exp(-JKD2/2);
tmp1 = d(1-min(T(:))+T(CC.TI,:),cl==CL(1));
[sum0,n0,ssq0] = sumskipnan(tmp1(:));
tmp2 = d(1-min(T(:))+T(CC.TI,:),cl==CL(2));
[sum1,n1,ssq1] = sumskipnan(tmp2(:));
CC.GRB.AUC = auc(tmp1,tmp2);
CC.GRB.ERR(1,1) = mean(sign([tmp1(:)]))/2+1/2;
CC.GRB.ERR(1,2) = mean(sign([tmp2(:)]))/2+1/2;
CC.GRB.ERR(2,1) = mean(sign([mean(tmp1,1)']))/2+1/2;
CC.GRB.ERR(2,2) = mean(sign([mean(tmp2,1)']))/2+1/2;
s0 = (ssq0-sum0.*sum0./n0)./(n0-1);
s1 = (ssq1-sum1.*sum1./n1)./(n1-1);
s  = (ssq0+ssq1-(sum0+sum1).*(sum0+sum1)./(n0+n1))./(n0+n1-1);
SNR = 2*s./(s0+s1); % this is SNR+1
CC.GRB.I = log2(SNR)/2;
CC.GRB.SNR = SNR - 1;
if 0,
    clear tmp1 tmp2;
    tmp1 = stat2(d(:,cl==CL(1)),2);
    tmp2 = stat2(d(:,cl==CL(2)),2);
    CC.GRB.TSD=stat2res(tmp1,tmp2);
    CC.GRB.TSD.ERR=1/2-mean(sign([-d(:,cl==CL(1)),d(:,cl==CL(2))]),2)/2;
elseif bitand(SWITCH,1),
    CC.GRB.TSD=bci3eval(d(:,cl==CL(1)),d(:,cl==CL(2)),2);
end;