lrds > BcicompIIIiva.m

BcicompIIIiva

PURPOSE ^

BcicompIIIiva.m - main script that applies the LRDS classifier to BCI Competition III dataset IVa

SYNOPSIS ^

This is a script file.

DESCRIPTION ^

 BcicompIIIiva.m - main script that applies the LRDS classifier to BCI
 Competition III dataset IVa

CROSS-REFERENCE INFORMATION ^

This script calls: (none listed)
This script is called by: (none listed)

SOURCE CODE ^

0001 % BcicompIIIiva.m - main script that applies the LRDS classifier to BCI
0002 % competition III dataset IVa
0003 %
0004 % For each subject: trains the LRDS classifier (lrds_dual) on the labelled
0005 % trials for a grid of regularization constants lambda, evaluates the 0-1
0006 % loss on the unlabelled trials against the true labels, and finally plots
0007 % test accuracy (%) versus lambda for every subject and their average.
0008 %
0009 % Files read (must be on the MATLAB path):
0010 %   data_set_IVa_<subject>.mat        - variables cnt, mrk, nfo
0011 %   data_set_IVa_<subject>_truth.mat  - variable true_y
0012 %   butter730.mat                     - filter coefficients b, a
0013 % Helpers used (defined elsewhere): cutoutTrials, covariance, whiten,
0014 % logmatrix, matmultcv, lrds_dual, apply_lrds, loss_0_1, getfieldarray.
0015
0016 file   = 'data_set_IVa_%s.mat';
0017 file_t = 'data_set_IVa_%s_truth.mat';
0018
0019
0020 subjects = {'aa','al','av','aw','ay'};
0021
0022 opt.ival= [500 3500];   % trial window passed to cutoutTrials (units presumably ms -- TODO confirm)
0023 opt.filtOrder= 5;       % NOTE(review): not used below; the filter comes from butter730.mat
0024 opt.band = [7 30];      % NOTE(review): not used below; passband of butter730.mat
0025 opt.logm = 0;           % nonzero: classify logmatrix() of the whitened covariances
0026
0027 %% Reduced set of 49 channels (column indices into cnt)
0028 opt.chanind = [14, 15, 16, 17, 18, 19, 20, 21, 22, 33, 34, 35, 36, 37, 38, ...
0029                39, 50, 51, 52, 53, 54, 55, 56, 57, 58, 68, 69, 70, 71, 72, ...
0030                73, 74, 75, 76, 87, 88, 89, 90, 91, 92, 93, 94, 95, 104, 106,...
0031               108, 112, 113, 114];
0032
0033 %% Load precomputed filter coefficients (7-30Hz Butterworth filter)
0034 butter = load('butter730.mat');
0035
0036 %% Regularization constants: 20 values log-spaced from 0.01 to 100
0037 lambda = exp(linspace(log(0.01), log(100), 20));
0038
0039 %% Per-(subject, lambda) results: trained classifier, test outputs, 0-1 loss
0040 memo = repmat(struct('lambda',[],'cls',[],'out',[],'loss',[]),...
0041               [length(subjects), length(lambda)]);
0042
0043 for jj=1:length(subjects)
0044   fprintf('Subject: %s\n', subjects{jj});
0045   
0046   %% Load a dataset into a struct (not the workspace) and preprocess
0047   dat = load(sprintf(file, subjects{jj}));
0048   mrk = dat.mrk;
0049   nfo = dat.nfo;
0050
0051   %% Select channels and convert cnt into double
0052   % (the 0.1 factor presumably converts the stored integers to microvolts -- confirm)
0053   cnt  = 0.1*double(dat.cnt(:,opt.chanind));
0054   clear dat  % drop the full-montage raw data; only the selected channels are needed
0055   clab = nfo.clab(opt.chanind);
0056   C = length(clab);
0057   
0058   %% Apply the band-pass filter (causal; filter() works along each column)
0059   cnt = filter(butter.b, butter.a, cnt);
0060   
0061   %% Cut EEG into trials and compute one covariance matrix per trial
0062   xepo = cutoutTrials(cnt, mrk.pos, opt.ival, nfo.fs);
0063   X = covariance(xepo);
0064   Y = (mrk.y-1.5)*2;  % convert {1,2} -> {-1, 1}; unlabelled (NaN) trials stay NaN
0065   
0066   %% Training set = labelled trials, test set = unlabelled trials
0067   Itrain = find(~isnan(Y));
0068   Itest  = find(isnan(Y));
0069
0070   %% Whiten the training data (Ww is reused for the test data)
0071   Xtr = X(:,:,Itrain);
0072   Ytr = Y(Itrain);
0073   [Xtr, Ww] = whiten(Xtr);
0074
0075   if opt.logm
0076     Xtr = logmatrix(Xtr);
0077     % the test covariances are whitened and logm'ed explicitly below,
0078     % so the classifier itself gets an identity whitening matrix
0079     Wcls = eye(C);
0080   else
0081     % NOTE(review): presumably apply_lrds applies cls.Ww to the raw
0082     % test covariances -- confirm
0083     Wcls = Ww;
0084   end
0085     
0086   %% Train the classifier for various values of lambda
0087   for ii=1:length(lambda)
0088     [W, bias] = lrds_dual(Xtr, Ytr, lambda(ii));
0089     memo(jj,ii).lambda = lambda(ii);
0090     memo(jj,ii).cls    = struct('W',W,'bias',bias,'Ww',Wcls);
0091   end
0092   
0093   %% Load the true labels of the test set and map {1,2} -> {-1,1}
0094   tru = load(sprintf(file_t, subjects{jj}));
0095   true_y = (tru.true_y(Itest)-1.5)*2;
0096
0097   Xte = X(:,:,Itest);
0098   if opt.logm
0099     Xte = logmatrix(matmultcv(Xte, Ww));
0100   end
0101     
0102   %% Apply the classifier
0103   fprintf('Subject: %s\n', subjects{jj});
0104   fprintf('lambda\t loss\n------------------------------------\n');
0105   for ii=1:length(lambda)
0106     memo(jj,ii).out  = apply_lrds(Xte, memo(jj,ii).cls);
0107     memo(jj,ii).loss = loss_0_1(true_y, memo(jj,ii).out);
0108     
0109     fprintf('%g\t%g\n', lambda(ii), memo(jj,ii).loss);
0110   end
0111 end
0112
0113 %% Collect test losses (presumably a [subjects x lambda] matrix, like memo)
0114 loss=cell2mat(getfieldarray(memo,'loss'));
0115
0116 %% Plot test accuracy (%) vs. lambda: one line per subject plus the average
0117 figure, plot(log(lambda), 100*(1-loss)', 'linewidth',2)
0118 set(gca,'fontsize',20)
0119 xt = [0.01 0.1 1 10 100];  % tick positions, matching the labels below
0120 set(gca,'xtick',log(xt))
0121 set(gca,'xticklabel', {'0.01', '0.1', '1.0', '10', '100'})
0122 grid on;
0123 hold on;
0124 % average over subjects along dim 1, so a single-subject run still works
0125 plot(log(lambda), 100*(1-mean(loss,1)), 'color',[.7 .7 .7], 'linewidth', 2);
0126 leg = [subjects {'average'}];
0127 legend(leg);
0128 xlabel('Regularization constant \lambda')
0129 ylabel('Classification accuracy (%)')

Generated on Sat 26-Apr-2008 15:48:23 by m2html © 2003