function [CC,Q,tsd,md]=findclassifier1(D,TRIG,cl,T,t0,SWITCH)
% FINDCLASSIFIER1 (obsolete, replaced by FINDCLASSIFIER)
%
% [CC,Q,TSD,MD]=findclassifier1(D,TRIG,Class,class_times,t_ref,SWITCH);
%
% D 	data, each row is one time point; a leading column of ones is
%	prepended unless the first column already consists of ones
% TRIG	trigger time points (row indices into D)
% Class class information, one integer label per trigger; NaN marks
%	unlabeled trials, which are only classified (see CC.OUT)
% class_times	classification times, combinations of times must be in one row;
%	each row holds sample offsets relative to the trigger
% t_ref	(optional) logical mask with one element per row of class_times;
%	only rows where it is true are candidates for the selected
%	time segment CC.TI (default or empty: all rows)
% SWITCH	(optional, default 0) bit mask; if bit 1 is set, the
%	fields ERR00 and TSD are computed as well (two-class case only)
%
% CC 	contains LDA and MD classifiers, among others
%	CC.TI	index of the selected row of class_times
%	CC.MD	class models (COVM, mode 'M') at time segment CC.TI
%	CC.lda	LDA weights [bias; weights] (two classes only)
%	CC.QC, CC.KAPPA, CC.ACC	quality, kappa and accuracy of the MD
%		classifier for each row of class_times (no cross-validation)
%	CC.mmx, CC.ACC00, CC.KAP00, CC.I0, CC.I1	leave-one-trial-out
%		(jackknife) results for every sample offset
%	CC.OUT	classification of the unlabeled (NaN) trials, if any
%		(two classes only)
% Q  	quality measure returned by QCMAHAL for the models of all
%	combinations of time segment and class (the per-segment
%	values are stored in CC.QC)
% TSD 	returns the LDA classification (two classes only, otherwise empty)
% MD	returns the MD  classification
%	TSD and MD have one row per row of D; note that D is padded with
%	NaN rows beforehand whenever TRIG+class_times reaches outside of D.
%
% [CC,Q,TSD,MD]=findclassifier1(AR,find(trig>0.5)-257,~mod(1:80,2),reshape(1:14*128,16,14*8)');
%
%
% Reference(s):
% [1] Schlögl A., Neuper C. Pfurtscheller G.
% 	Estimating the mutual information of an EEG-based Brain-Computer-Interface
%  	Biomedizinische Technik 47(1-2): 3-8, 2002.
% [2] A. Schlögl, C. Keinrath, R. Scherer, G. Pfurtscheller,
%	Information transfer of an EEG-based Brain-Computer Interface.
%	Proceedings of the 1st International IEEE EMBS Conference on Neural Engineering, Capri, Italy, Mar 20-22, 2003


%   Copyright (C) 1999-2004 by Alois Schloegl <alois.schloegl@gmail.com>
%	$Id$


% This program is free software; you can redistribute it and/or
% modify it under the terms of the GNU General Public License
% as published by the Free Software Foundation; either version 2
% of the  License, or (at your option) any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.


warning('this function is obsolete and replaced by FINDCLASSIFIER');


tsd=[];md=[];

if nargin<6,
        SWITCH=0;
end;
% t0 masks the rows of T that may be selected as best time segment
if nargin<5 || isempty(t0),
        t0=true(size(T,1),1);
end;

% class labels must be integers; NaN marks unlabeled trials
tmp=cl;tmp(isnan(tmp))=0;
if any(rem(tmp,1) & ~isnan(cl)),
        fprintf(2,'Error %s: class information is not integer\n',mfilename);
        return;
end;
if length(TRIG)~=length(cl);
        fprintf(2,'Warning %s: number of Triggers do not match class information\n',mfilename);
end;

CL   = unique(cl(~isnan(cl)));   % sorted list of class labels
nCL  = length(CL);
nT   = size(T,1);
TRIG = TRIG(:);

% prepend a bias column of ones unless D already has one
if ~all(D(:,1)==1)
        D =[ones(size(D,1),1),D];
end;

% pad D with NaN rows so that every index TRIG+T lies within D
tmp = min(TRIG)+min(T(:))-1;
if tmp<0,
        TRIG = TRIG - tmp;
        D = [repmat(nan,[-tmp,size(D,2)]);D];
end;
tmp = max(TRIG)+max(T(:))-size(D,1);
if tmp>0,
        D = [D;repmat(nan,[tmp,size(D,2)])];
end;

% estimate classification result for all time segments in T - without crossvalidation
CMX = zeros([nT,nCL,nCL]);
C   = cell(nT,nCL);
dmx = cell(1,nT);
lnQ = zeros(1,nT);
for k = 1:nT,
        % class-wise models of the samples of time segment k
        tk = cell(1,nCL);
        for l = 1:nCL,
                tk{l}  = perm(TRIG(cl==CL(l)),T(k,:));
                C{k,l} = covm(D(tk{l}(:),:),'M');
        end;
        [CC.QC(k),dmx{k}] = qcmahal(C(k,:));
        lnQ(k) = mean(log(dmx{k}(~eye(length(dmx{k})))));

        % confusion matrix: rows = assigned class, columns = true class,
        % both as positions in CL (labels need not be 1..nCL)
        cmx = zeros(nCL);
        for l = 1:nCL,
                [tmp,ix] = min(mdbc(C(k,:),D(tk{l}(:),:)),[],2);
                ix(isnan(tmp)) = NaN;   % not classified, not counted
                for m = 1:nCL,
                        cmx(m,l) = sum(ix==m);
                end;
        end;
        CMX(k,:,:)  = cmx;
        CC.KAPPA(k) = kappa(cmx);
        CC.ACC(k)   = sum(diag(cmx))/sum(cmx(:));
end;

% identify best classification time; masked-out segments can never win
tmp = CC.QC;
tmp(~t0) = -Inf;
[tmp,CC.TI] = max(tmp);

% build MD classifier
CC.MD  = C(CC.TI,:);
CC.IR  = mdbc(C(CC.TI,:));
CC.D   = dmx{CC.TI};
CC.Q   = CC.QC(CC.TI);
CC.CMX = reshape(CMX(CC.TI,:,:),[nCL,nCL]);

m1=decovm(CC.MD{1});
m2=decovm(CC.MD{2});
tmp=mdbc(CC.MD,[1,m1;1,m2]);
CC.scale=[1,1]*max(abs(tmp(:)));  % element 1

[tmp,CC.lnTI] = max(lnQ);
CC.DistMXln = dmx{CC.lnTI};
CC.MDln = C(CC.lnTI,:);

% alternative classifier using two different time segments.
[Q,d2] = qcmahal(C(:));
CC.T2.D = d2;
[ix,iy] = find(d2==max(d2(:)));
ix = mod(ix(1)-1,nT)+1;      % map linear model index back to time segment
iy = mod(iy(1)-1,nT)+1;
CC.T2.TI = [ix,iy];
CC.T2.MD = {C{ix,1},C{iy,2}};

% build LDA classifier
if nCL==2,
        M  = cell(1,nCL);
        C0 = zeros(size(C{CC.TI,1}));
        for l = 1:nCL,
                M{l} = decovm(C{CC.TI,l});
                C0   = C0 + C{CC.TI,l};
        end;
        [M0,sd,COV0] = decovm(C0);
        w     = COV0\(M{1}'-M{2}');
        w0    = M0*w;
        CC.lda = [w0; -w];

        tsd = ldbc(CC.MD,D);
end;
md = mdbc(CC.MD,D);

% bias correction not used anymore.
CC.BIAS.LDA=0;
CC.BIAS.MDA=0;
CC.BIAS.GRB=0;

%% cross-validation with jackknife (leave one trial out)
nc    = max(T(:))-min(T(:))+1;    % number of sample offsets spanned by T
nTR   = length(TRIG);
JKD   = repmat(nan,[nc,nCL,nTR]);
JKD1  = repmat(nan,[nc,nTR]);
JKD2  = repmat(nan,[nc,nTR]);
JKLD  = repmat(nan,[nc,nTR]);
MDIX  = repmat(nan,[nc,nTR]);     % assigned class, as position in CL
LDA   = [];
valid = find(~isnan(cl(:)'));     % labeled trials
for l = valid,
        c = find(cl(l)==CL);

        % subtract this trial's COVM matrix from the model of its class
        t  = TRIG(l)+T(CC.TI,:);
        cc = CC.MD;
        cc{c} = CC.MD{c} - covm(D(t(:),:),'M');

        % evaluate the reduced models at every sample offset of T
        t  = TRIG(l)+(min(T(:)):max(T(:)));
        djk = mdbc(cc,D(t,:));
        JKD(:,:,l) = djk;
        [tmp,ix] = min(djk,[],2);
        ix(isnan(tmp)) = NaN;   % not classified
        MDIX(:,l) = ix;

        if nCL==2,
                JKD1(:,l) = djk(:,1);
                JKD2(:,l) = djk(:,2);

                LDA(:,l)  = ldbc(cc);
                JKLD(:,l) = D(t,:)*LDA(:,l);
        end;
end;
if nCL==2,
        % only labeled trials have a jackknife LDA weight vector
        CC.ldaC0 = covm(LDA(:,valid)','D0');
end;

% Concordance matrix with cross-validation
CC.mmx = zeros([nc,nCL^2]);
CC.I0  = zeros([nc,nCL]);
cnt    = zeros([nc,nCL]);
for k = 1:nCL,
        own   = (cl==CL(k));
        other = (cl~=CL(k)) & ~isnan(cl);
        jkd   = reshape(JKD(:,k,:),[nc,nTR]);   % distances to model k

        CC.TSD{k}  = bci3eval(jkd(:,other),jkd(:,own),2);
        CC.I0(:,k) = log2(2*var(jkd(:,~isnan(cl)),[],2)./(var(jkd(:,own),[],2) + var(jkd(:,other),[],2)))/2;

        [sum0,n0,ssq0] = sumskipnan(jkd(:,own),2);
        [sum1,n1,ssq1] = sumskipnan(jkd(:,other),2);
        CC.I1(:,k) = log2(local_snr1(sum0,n0,ssq0,sum1,n1,ssq1))/2;

        % cnt(:,m): number of class-k trials assigned to class CL(m)
        for m = 1:nCL,
                cnt(:,m) = sum(MDIX(:,own)==m,2);
        end;
        CC.mmx(:,(k-1)*nCL+(1:nCL)) = cnt;
        CC.acc(:,k) = cnt(:,k)./sum(cnt,2);
end;

% rows of the jackknife results that belong to time segment CC.TI
ixT = 1-min(T(:))+T(CC.TI,:);

CC.CMX00 = reshape(sum(CC.mmx(ixT,:),1),[nCL,nCL])/size(T,2);
CC.I     = sum(CC.I0,2);
CC.ACC00 = sum(CC.mmx(:,1:nCL+1:end),2)/sum(~isnan(cl));
CC.KAP00 = zeros(nc,1);
for k = 1:nc,
        CC.KAP00(k) = kappa(reshape(CC.mmx(k,:),[nCL,nCL]));
end;

if nCL > 2,
        return;
end;

% statistics of the cross-validated two-class discriminant outputs
CC.LDA = local_jkstats(JKLD,                  ixT, cl, CL, SWITCH);
CC.MDA = local_jkstats(JKD1 - JKD2,           ixT, cl, CL, SWITCH);
CC.MD2 = local_jkstats(sqrt(JKD1)-sqrt(JKD2), ixT, cl, CL, SWITCH);

if bitand(SWITCH,1),
        CC.LDA.ERR00 = (mean(sign(JKLD),2)+1)/2;
        CC.MDA.ERR00 = (mean(sign(JKD1-JKD2),2)+1)/2;
        CC.GRB.ERR00 = (mean(sign(exp(-JKD2/2)-exp(-JKD1/2)),2)+1)/2;
end;

% classify the unlabeled trials
if any(isnan(cl)),
        t = perm(TRIG(isnan(cl)),T(CC.TI,:));
        t = t(t<=size(D,1));
        tmp = rs(D(t(:),:),size(T,2),1);
        CC.OUT.LDA   = ldbc(CC.MD,tmp);
        CC.OUT.LDAcl = CL((CC.OUT.LDA>0)+1);
        CC.OUT.MDA   = mdbc(CC.MD,tmp);
        [tmp,ix] = min(CC.OUT.MDA,[],2);
        tmp = isnan(tmp);
        ix(tmp)  = NaN;   % invalid output, not classified
        ix(~tmp) = CL(ix(~tmp));
        CC.OUT.MDAcl = ix;
end;



function S = local_jkstats(d, ix, cl, CL, SWITCH)
% LOCAL_JKSTATS summary of a cross-validated two-class discriminant output
%
% d      discriminant output, one row per sample offset, one column per trial
% ix     rows of d that belong to the selected time segment
% cl     class label of each trial (column of d)
% CL     the two class labels
% SWITCH if bit 1 is set, S.TSD is computed with BCI3EVAL as well
%
% S.AUC    AUC(d of class CL(1), d of class CL(2)) on rows ix
% S.ERR    S.ERR(r,k) = fraction of positive outputs (zeros count half)
%          of class CL(k); row 1 over all samples, row 2 over the
%          per-trial means
% S.I      log2(SNR+1)/2
% S.SNR    SNR of the two groups
tmp1 = d(ix, cl==CL(1));
tmp2 = d(ix, cl==CL(2));
[sum0,n0,ssq0] = sumskipnan(tmp1(:));
[sum1,n1,ssq1] = sumskipnan(tmp2(:));
S.AUC      = auc(tmp1,tmp2);
S.ERR(1,1) = mean(sign(tmp1(:)))/2+1/2;
S.ERR(1,2) = mean(sign(tmp2(:)))/2+1/2;
S.ERR(2,1) = mean(sign(mean(tmp1,1)'))/2+1/2;
S.ERR(2,2) = mean(sign(mean(tmp2,1)'))/2+1/2;
SNR1  = local_snr1(sum0,n0,ssq0,sum1,n1,ssq1);
S.I   = log2(SNR1)/2;
S.SNR = SNR1 - 1;
if bitand(SWITCH,1),
        S.TSD = bci3eval(d(:,cl==CL(1)),d(:,cl==CL(2)),2);
end;



function SNR1 = local_snr1(sum0,n0,ssq0,sum1,n1,ssq1)
% LOCAL_SNR1 returns SNR+1 = 2*s./(s0+s1) from the sums, counts and sums
% of squares of two groups: s0, s1 are the group variances and s is the
% variance of the pooled data. Works element-wise on vectors.
s0   = (ssq0-sum0.*sum0./n0)./(n0-1);
s1   = (ssq1-sum1.*sum1./n1)./(n1-1);
s    = (ssq0+ssq1-(sum0+sum1).*(sum0+sum1)./(n0+n1))./(n0+n1-1);
SNR1 = 2*s./(s0+s1);
421