0001
0002
0003
% File-name templates for the raw EEG recordings and for the true test
% labels of BCI Competition III, data set IVa (one file pair per subject).
file   = 'data_set_IVa_%s.mat';
file_t = 'data_set_IVa_%s_truth.mat';

% The five subjects of data set IVa.
subjects = {'aa','al','av','aw','ay'};

% Preprocessing / feature-extraction options.
opt.ival      = [500 3500];   % trial window relative to the cue -- presumably [ms]; confirm
opt.filtOrder = 5;            % band-pass filter order (informational; coefficients are precomputed)
opt.band      = [7 30];       % pass band -- presumably [Hz]; matches butter730.mat, confirm
opt.logm      = 0;            % 1 = take the matrix logarithm of the trial covariances

% Indices of the channels kept for the analysis.
opt.chanind = [ 14  15  16  17  18  19  20  21  22 ...
                33  34  35  36  37  38  39 ...
                50  51  52  53  54  55  56  57  58 ...
                68  69  70  71  72  73  74  75  76 ...
                87  88  89  90  91  92  93  94  95 ...
               104 106 108 112 113 114];

% Precomputed Butterworth band-pass coefficients (fields .b and .a).
butter = load('butter730.mat');

% Candidate regularization constants, log-spaced on [0.01, 100].
lambda = exp(linspace(log(0.01), log(100), 20));

% Result table: one record per (subject, lambda) pair.
blank = struct('lambda',[],'cls',[],'out',[],'loss',[]);
memo  = repmat(blank, length(subjects), length(lambda));
0028
% Per-subject pipeline: preprocess, train a classifier for every lambda on
% the labeled trials, then evaluate all of them on the unlabeled test trials.
for iSubj = 1:length(subjects)
  fprintf('Subject: %s\n', subjects{iSubj});

  % Loads (at least) cnt, mrk and nfo into the workspace.
  load(sprintf(file, subjects{iSubj}));

  % Keep the selected channels and convert to double.
  % NOTE(review): the 0.1 factor presumably rescales the stored integer
  % units (e.g. to microvolts) -- confirm against the data set description.
  cnt  = 0.1*double(cnt(:,opt.chanind));
  clab = nfo.clab(opt.chanind);
  C    = length(clab);

  % Band-pass filter with the precomputed Butterworth coefficients.
  cnt = filter(butter.b, butter.a, cnt);

  % Cut trials out of the continuous signal, compute one spatial covariance
  % matrix per trial, and map the class labels {1,2} -> {-1,+1}.
  xepo = cutoutTrials(cnt, mrk.pos, opt.ival, nfo.fs);
  X    = covariance(xepo);
  Y    = (mrk.y-1.5)*2;

  % Labeled trials form the training set; NaN-labeled ones the test set.
  Itrain = find(~isnan(Y));
  Itest  = find(isnan(Y));

  Xtr = X(:,:,Itrain);
  Ytr = Y(Itrain);
  [Xtr, Ww] = whiten(Xtr);

  if opt.logm
    Xtr = logmatrix(Xtr);
  end

  % Train one classifier per candidate regularization constant.
  for iLam = 1:length(lambda)
    [W, bias] = lrds_dual(Xtr, Ytr, lambda(iLam));
    memo(iSubj,iLam).lambda = lambda(iLam);
    if ~opt.logm
      memo(iSubj,iLam).cls = struct('W',W,'bias',bias,'Ww',Ww);
    else
      % In the log-matrix branch the whitening is applied to the test data
      % below, so the stored classifier carries an identity matrix instead.
      memo(iSubj,iLam).cls = struct('W',W,'bias',bias,'Ww',eye(C));
    end
  end

  % Loads the competition ground truth (true_y) for this subject.
  load(sprintf(file_t, subjects{iSubj}));
  true_y = (true_y(Itest)-1.5)*2;

  Xte = X(:,:,Itest);

  if opt.logm
    Xte = logmatrix(matmultcv(Xte, Ww));
  end

  % Evaluate every trained classifier on the held-out test trials.
  fprintf('Subject: %s\n', subjects{iSubj});
  fprintf('lambda\t loss\n------------------------------------\n');
  for iLam = 1:length(lambda)
    memo(iSubj,iLam).out  = apply_lrds(Xte, memo(iSubj,iLam).cls);
    memo(iSubj,iLam).loss = loss_0_1(true_y, memo(iSubj,iLam).out);

    fprintf('%g\t%g\n', lambda(iLam), memo(iSubj,iLam).loss);
  end
end
0093
% Collect the per-(subject, lambda) losses into an
% [nSubjects x nLambda] matrix.
loss=cell2mat(getfieldarray(memo,'loss'));

% Plot classification accuracy (in %) against log(lambda): one curve per
% subject, plus the grand average over subjects in gray.
figure, plot(log(lambda), 100*(1-loss)', 'linewidth',2)
set(gca,'fontsize',20)
set(gca,'xtick',log(0.01):log(10):log(100))
set(gca,'xticklabel', {'0.01', '0.1', '1.0', '10', '100'})
grid on;
hold on;
plot(log(lambda), 100*(1-mean(loss)), 'color',[.7 .7 .7], 'linewidth', 2);
hold off;  % do not leak the hold state past this script
leg = [subjects {'average'}];
legend(leg);
xlabel('Regularization constant \lambda')
% BUG FIX: the y-axis shows 100*(1-loss), i.e. accuracy, not error; the
% previous label 'Classification error' contradicted the plotted quantity.
ylabel('Classification accuracy (%)')