lrds > lrds_dual.m

lrds_dual

PURPOSE ^

lrds_dual - Logistic regression with dual spectral regularization

SYNOPSIS ^

function [W, bias, z, status]=lrds_dual(X, Y, lambda, varargin)

DESCRIPTION ^

 lrds_dual - Logistic regression with dual spectral regularization
             for symmetric input matrices.

 Objective:
    Solves the regularized ERM problem:
        sum(loss(f(X_i),y_i)) + lambda*sum(svd(W))
    where
        f(X) = trace(W'*X) + bias

 Syntax:
  [W, bias, z, status]=lrds_dual(X, Y, lambda, <opt>)

 Inputs:
  X      : Input matrices.
           [C,C,n] array: Each X(:,:,i) assumed to be symmetric.
           [C^2,n] array: Each reshape(X(:,i), sqrt(size(X,1))) assumed to be symmetric.
  Y      : Binary labels. +1 or -1. [n,1] matrix. n is the number of samples.
  lambda : Regularization constant.
  <opt>  : Options.
    .tol     : absolute tolerance for duality gap (1e-6)
    .tolX    : relative tolerance for step size (1e-6)
    .tmul    : barrier parameter multiplier (20)
    .maxiter : maximum number of iterations (1000)
    .display : 'all'(4), 'every'(3), 'iter'(2), 'final'(1), or 'none'(0) ('iter')

 Outputs:
  W      : Weight matrix. [C,C] matrix.
  bias   : Bias term.
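  z      : Classifier outputs: Y.*f(X) - beta1 + beta2, where beta1 and beta2 are the
           multipliers of the box constraints 0<=alpha<=1. [n,1] matrix.
  status : Struct of diagnostics with fields opt, niter, t, gap, obj, beta1, beta2,
           alpha and time (CPU seconds).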

 Reference:
 "Classifying Matrices with a Spectral Regularization",
 Ryota Tomioka and Kazuyuki Aihara,
 Proc. ICML2007, pp. 895-902, ACM Press; Oregon, USA, June, 2007.
 http://www.machinelearning.org/proceedings/icml2007/papers/401.pdf

 This software is distributed from:
 http://www.sat.t.u-tokyo.ac.jp/~ryotat/lrds/index.html

 Ryota Tomioka 2007.

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function [W, bias, z, status]=lrds_dual(X, Y, lambda, varargin)
% lrds_dual - Logistic regression with dual spectral regularization
%             for symmetric input matrices.
%
% Objective:
%    Solves the regularized ERM problem:
%        sum(loss(f(X_i),y_i)) + lambda*sum(svd(W))
%    where
%        f(X) = trace(W'*X) + bias
%    The dual problem is solved with a log-barrier interior-point method:
%    Newton steps on the dual variables alpha for a fixed barrier
%    parameter t; once these converge, t is multiplied by opt.tmul, until
%    the duality gap falls below opt.tol (or opt.maxiter is reached).
%
% Syntax:
%  [W, bias, z, status]=lrds_dual(X, Y, lambda, <opt>)
%
% Inputs:
%  X      : Input matrices.
%           [C,C,n] array: Each X(:,:,i) assumed to be symmetric.
%           [1,C^2,n] array: Each reshape(X(1,:,i), [C,C]) assumed to be symmetric.
%           [C^2,n] array: Each reshape(X(:,i), sqrt(size(X,1))) assumed to be symmetric.
%  Y      : Binary labels. +1 or -1. [n,1] matrix. n is the number of samples.
%           Both classes must be present.
%  lambda : Regularization constant (positive scalar).
%  <opt>  : Options, given as a property list or as a struct.
%    .tol     : absolute tolerance for duality gap (1e-6)
%    .tolX    : relative tolerance for step size (1e-6)
%    .tmul    : barrier parameter multiplier (20)
%    .maxiter : maximum number of iterations (1000)
%    .display : 'all'(4), 'every'(3), 'iter'(2), 'final'(1), or 'none'(0) ('iter')
%    .alpha   : (optional) initial dual variables, [n,1], with 0<alpha<1.
%               The Newton steps keep Y'*alpha at its initial value
%               (0 for the default start).
%    .t       : (optional) initial barrier parameter
%
% Outputs:
%  W      : Weight matrix. [C,C] matrix.
%  bias   : Bias term.
%  z      : Classifier outputs: Y.*f(X) - beta1 + beta2, where beta1 and beta2
%           are the multipliers of the box constraints 0<=alpha<=1. [n,1] matrix.
%  status : Struct of diagnostics with fields opt, niter, t, gap, obj,
%           beta1, beta2, alpha and time (CPU seconds).
%
% Reference:
% "Classifying Matrices with a Spectral Regularization",
% Ryota Tomioka and Kazuyuki Aihara,
% Proc. ICML2007, pp. 895-902, ACM Press; Oregon, USA, June, 2007.
% http://www.machinelearning.org/proceedings/icml2007/papers/401.pdf
%
% This software is distributed from:
% http://www.sat.t.u-tokyo.ac.jp/~ryotat/lrds/index.html
%
% Ryota Tomioka 2007.

%% Options
try
  opt = propertylist2struct(varargin{:});
catch
  % Otherwise use the fourth argument as the option struct as-is.
  if nargin>3
    opt = varargin{1};
  else
    opt = [];
  end
end

opt = setDefaults(opt, struct('tol', 1e-6, ...
                              'tolX', 1e-6, ...
                              'tmul', 20, ...
                              'maxiter', 1000, ...
                              'display', 'iter'));

% Map the display keyword to its numeric level (0..4).
if ~isnumeric(opt.display)
  level = find(strcmp(opt.display, {'none','final','iter','every','all'}))-1;
  if isempty(level)
    error('lrds_dual:badOption', ...
          'opt.display must be ''none'', ''final'', ''iter'', ''every'' or ''all''.');
  end
  opt.display = level;
end

%% Input matrices: keep both the vectorized [C^2,n] form (Xf) and the
%% [C,C,n] form (X).
sz = size(X);
if ndims(X)==3 && sz(1)==sz(2)
  % [C,C,n]
  C  = sz(1);
  n  = sz(3);
  Xf = reshape(X, [C*C, n]);
elseif ndims(X)==3 && sz(1)==1
  % [1,C^2,n]: drop the leading singleton dimension.
  n  = sz(3);
  Xf = reshape(X, [sz(2), n]);
  C  = sqrt(sz(2));
elseif ndims(X)==2
  % [C^2,n]
  n  = sz(2);
  Xf = X;
  C  = sqrt(sz(1));
else
  error('lrds_dual:badInput', 'X must be a [C,C,n], [1,C^2,n] or [C^2,n] array.');
end
if C~=round(C)
  error('lrds_dual:badInput', 'Each sample of X must have a square number of elements.');
end
X = reshape(Xf, [C, C, n]);

if n~=length(Y)
  error('Sample size mismatch!');
end

Y=shiftdim(Y);

% The default start and the equality constraint Y'*alpha = const need
% labels in {-1,+1} with both classes present.
if any(Y~=1 & Y~=-1) || ~any(Y>0) || ~any(Y<0)
  error('lrds_dual:badInput', ...
        'Y must contain only +1 and -1 labels, with both classes present.');
end
if ~isscalar(lambda) || ~(lambda>0)
  error('lrds_dual:badInput', 'lambda must be a positive scalar.');
end

%% Initial dual variables
if isfield(opt,'alpha')
  alpha = opt.alpha;
else
  % Equal total mass on each class, so that Y'*alpha = 0.
  alpha = zeros(n,1);
  alpha(Y>0) = min(lambda*0.01,1)/sum(Y>0);
  alpha(Y<0) = min(lambda*0.01,1)/sum(Y<0);
end

cc = 0;
display_line_search = opt.display==4;

%% Initial barrier parameter
if isfield(opt,'t')
  t = opt.t;
else
  t = 2*(C+n)/(n*log(2));
end

I_C = eye(C);

% Defined up front so that every output exists even if no iteration
% is performed (e.g. opt.maxiter<=0).
W     = zeros(C);
bias  = 0;
z     = zeros(n,1);
beta1 = zeros(n,1);
beta2 = zeros(n,1);
gap   = [];
obj   = [];
gg    = [];

cc0 = 0;
time0 = cputime;
time00 = time0;

%% Outer loop: increase t; inner loop: Newton iterations for fixed t
while cc<opt.maxiter
  while cc<opt.maxiter
    cc = cc + 1;

    A = reshape(Xf*(alpha.*Y), [C,C]); A=(A+A')/2;

    % Cholesky factors: R'*R = lambda*I -/+ A (R upper triangular).
    % chol errors if either matrix is not positive definite.
    R1 = chol(lambda*I_C-A);
    R2 = chol(lambda*I_C+A);

    [loss, gl, hl]=lossDual(alpha);

    %% Gradient and Hessian associated to the barrier log det terms
    SX1 = zeros(C*C,n);
    SX2 = zeros(C*C,n);
    gS  = zeros(n,1);
    for i=1:n
      D1 = R1'\(Y(i)*X(:,:,i))/R1;
      D2 = R2'\(Y(i)*X(:,:,i))/R2;
      SX1(:,i)=reshape(D1, [C*C,1]);
      SX2(:,i)=reshape(D2, [C*C,1]);
      gS(i) = sum(diag(D1-D2));
    end

    H1 = SX1'*SX1;
    H2 = SX2'*SX2;

    %% Gradient (entropy + box barrier + log det barrier)
    g = gl+((2*alpha-1)./(alpha.*(1-alpha)) + gS)/t;

    %% Hessian
    Hd = hl + (alpha.^(-2)+(1-alpha).^(-2))/t;
    Hr = (H1+H2)/t;
    H = diag(Hd)+Hr;

    HIg = H\g;
    HIy = H\Y;

    %% Lagrangian multiplier of the equality constraint; chosen so that
    %% Y'*delta = 0, i.e. the step keeps Y'*alpha unchanged.
    nu = - (Y'*HIg)/(Y'*HIy);

    %% Newton direction
    delta = -H\(g+Y*nu);

    alpha0 = alpha;

    %% Line search to determine the step size s
    Sd0 = reshape(Xf*(delta.*Y), [C,C]); Sd0=(Sd0+Sd0')/2;
    Sd1 = -eig(R1'\Sd0/R1);
    Sd2 = eig(R2'\Sd0/R2);

    [s, dloss] = lineSearch(alpha, delta, t, Sd1, Sd2, Y*nu, opt.tolX/max(abs(delta)./abs(alpha)), display_line_search);

    %% Update
    alpha = alpha0 + s*delta;

    A = reshape(Xf*(alpha.*Y), [C,C]); A=(A+A')/2;

    %% Weight matrix
    RR = chol(lambda^2*I_C-A*A');

    W = 2*(RR\((RR')\A))/t;
    W = (W+W')/2;
    trQ = 2*lambda*trace(((RR')\I_C)/RR)/t;

    %% trQ = trace(Q1) + trace(Q2), where
    %% (lambda*eye(C) - A) * Q1 = 1/t *eye(C)
    %% (lambda*eye(C) + A) * Q2 = 1/t *eye(C)
    %%
    %% Note that trQ is not exactly sum(svd(W)) until convergence

    %% Lagrangian multipliers of the box constraints 0<=alpha<=1
    beta1= 1./(t*alpha);
    beta2= 1./(t*(1-alpha));

    %% Bias term (= Lagrangian multiplier of the equality constraint)
    bias = nu;

    z=Y.*(reshape(W,[1,C*C])*Xf+bias)'-beta1+beta2;

    %% Primal objective
    loss_prim = lossPrime(z)+sum(beta2)+lambda*trQ;

    %% Dual objective (at new alpha)
    loss=lossDual(alpha);

    %% Primal - dual gap
    gap(cc) = loss_prim - (-loss);

    %% Interior-point objective function
    %% (-2*sum(log(diag(RR))) = -log det(lambda^2*I - A*A'))
    obj(cc) = loss +1/t*(-2*sum(log(diag(RR)))...
                         -sum(log(alpha))-sum(log(1-alpha)));

    %% IP first order optimality
    gg(cc)  = max(abs(g+Y*nu));

    %% Check validity (disabled debugging code: compares W, z, trQ and the
    %% objectives against an eigendecomposition-based computation)
    if 0
    [Va, Da]=eig(A);
    da=diag(Da);
    lmW = 2*da./(t*(lambda^2-da.^2));

    W0  = Va*diag(lmW)*Va';
    trQ0 = sum(2*lambda./(t*(lambda^2-da.^2)));
    z0=Y.*(reshape(W0,[1,C*C])*Xf+bias)'-beta1+beta2;
    loss_prim0 = lossPrime(z0)+sum(beta2)+lambda*trQ0;

    obj0 = loss +1/t*(-sum(log(lambda-da))-sum(log(lambda+da))...
                         -sum(log(alpha))-sum(log(1-alpha)));

    fprintf('!!! |W-W0|=%g dz=%g, dtrQ=%g dloss_prim=%g dobj=%g\n',...
            max(abs(rangeof(W-W0))),...
            max(abs(rangeof(z-z0))),...
            trQ-trQ0,...
            loss_prim-loss_prim0,...
            obj(cc)-obj0);
    end

    if opt.display>=3
      fprintf('[%d] t=%g gap=%g(>=%g) gg=%g y*alpha=%g nu=%g Hmin=%g s=%g obj=%g',...
              cc, t, gap(cc), 2*(C+n)/t, gg(cc), Y'*alpha, nu, min(eig(H)), s, obj(cc));

      if cc>1
        fprintf(' dloss=(%g/%g)\n', obj(cc)-obj(cc-1), dloss);
      else
        fprintf('\n');
      end
    end

    %% Inner convergence: first-order optimality, or (small gap/optimality
    %% residual and a negligible relative step)
    relstep = s*max(abs(delta)./abs(alpha));
    if gg(cc)<opt.tol || ...
       ((gg(cc)<opt.tol*min(100,gap(cc)/opt.tol) || gap(cc)<opt.tol) && relstep<opt.tolX)

      tlap = cputime-time0;

      if opt.display>=2
        fprintf('t=%g: gap=%g gg=%g nsteps=%d time=%g\n',...
                t,gap(cc),gg(cc),cc-cc0,tlap);
      end
      cc0 = cc;
      time0 = cputime;

      break;
    end
  end
  if gap(cc)<opt.tol
    break;
  else
    t = t*opt.tmul;
  end
end


status = struct('opt',opt,...
                'niter',cc,...
                't',t,...
                'gap',gap,...
                'obj',obj,...
                'beta1',beta1,...
                'beta2',beta2,...
                'alpha', alpha,...
                'time', cputime-time00);

if opt.display>0
  if isempty(gap)
    lastgap = nan;
  else
    lastgap = gap(end);
  end
  fprintf('[%d] gap=%g total time=%g\n', cc, lastgap, cputime-time00);
end
0282 
0283 
function loss = lossPrime(z)
% lossPrime - logistic loss summed over samples: sum(log(1+exp(-z))).
%
%   For z<0 the identity log(1+exp(-z)) = log(1+exp(z)) - z is used, so exp
%   is never applied to a large positive argument. log1p keeps full accuracy
%   when exp(.) is tiny (e.g. large positive z), where log(1+x) rounds to 0.
%
%   Input : z, vector of (adjusted) margins.
%   Output: scalar loss.
%   Raises: 'lrds_dual:lossPrime' if the result disagrees with the direct
%           formula beyond a relative tolerance (the direct formula is
%           skipped when it overflows to Inf).
neg    = z(z<0);
nonneg = z(z>=0);
loss = sum(log1p(exp(neg))) - sum(neg) + sum(log1p(exp(-nonneg)));

% Consistency check against the direct formula. The tolerance is relative
% because both sums grow with the number of samples and the size of |z|.
loss0 = sum(log(1+exp(-z)));
if ~isinf(loss0) && abs(loss-loss0) > 1e-9*max(1,abs(loss0))
  error('lrds_dual:lossPrime', ...
        'Inconsistent logistic loss: stable=%g, direct=%g.', loss, loss0);
end
0294 
0295 
function [loss, g, h] = lossDual(alpha)
% lossDual - negative binary entropy of the dual variables.
%
%   loss = sum(a.*log(a) + (1-a).*log(1-a)), with the convention that
%          entries equal to 0 or 1 contribute 0.
%   g    = elementwise first derivative  log(a./(1-a)),   NaN where a is 0 or 1.
%   h    = elementwise second derivative 1./a + 1./(1-a), NaN where a is 0 or 1.
%   g and h have the same size as alpha.
inside = (alpha ~= 0) & (alpha ~= 1);
a = alpha(inside);
b = 1 - a;

terms = zeros(size(alpha));
terms(inside) = a.*log(a) + b.*log(b);
loss = sum(terms);

g = nan(size(alpha));
g(inside) = log(a./b);

h = nan(size(alpha));
h(inside) = 1./a + 1./b;
0314 
0315 
0316 function [s,dloss] = lineSearch(alpha,delta,t,Sd1,Sd2,gnu,tolX,display);
     % lineSearch - choose a step size s along the direction delta.
     %
     % Inputs:
     %   alpha   : current dual variables; trial points are alpha + s*delta
     %   delta   : search (Newton) direction
     %   t       : barrier parameter
     %   Sd1,Sd2 : vectors such that the log-det barrier terms change by
     %             -(sum(log(1+s*Sd1))+sum(log(1+s*Sd2)))/t at step s;
     %             a step is feasible only if 1+s*Sd1>0 and 1+s*Sd2>0
     %   gnu     : vector whose linear term s*gnu'*delta is added to dloss
     %   tolX    : step threshold used by the second exit test below
     %   display : if true, print one trace line per trial step
     % Outputs:
     %   s       : the last trial step evaluated when the search stops
     %             (NOTE(review): not necessarily s_best, the best step seen)
     %   dloss   : change, relative to s=0, of lossDual plus the barrier terms
     %             plus s*gnu'*delta, evaluated at the returned s
     %
     % The search keeps a bracket [s1,s2] around the best step s_best
     % (s2 = NaN until an upper end is found). While s2 is NaN the step is
     % doubled. Otherwise the next trial is a randomly perturbed midpoint of
     % [s1,s2]. Infeasible trials shrink s2. The search stops at a feasible
     % trial once s2-s1<0.01.
0317 snew = 1;
0318 s1 = 0;
0319 s2 = nan;
0320 s_best = 0;
0321 alpha0 = alpha;
0322 loss0 = lossDual(alpha0);
0323 
0324 dloss_best = 0;
0325 cc = 1;
0326 
0327 
0328 while 1
0329   s = snew;
0330   cc = cc +1;
0331   
0332   if display
0333     if s_best > s
0334       fprintf(' %02d: s1=%.2f   s=%.2f * s2=%.2f',cc,s1,s,s2);
0335     elseif s_best < s
0336       fprintf(' %02d: s1=%.2f * s=%.2f   s2=%.2f',cc,s1,s,s2);
0337     else
0338       fprintf(' %02d: _s1=%.2f   s=%.2f   s2=%.2f',cc,s1,s,s2);
0339     end
0340   end
0341   
0342   lm1 = 1+s*Sd1;
0343   lm2 = 1+s*Sd2;
0344 
0345   alpha = alpha0 + s*delta;
0346   
0347   isfeas = 0;
0348   if any(lm1<=0) | any(lm2<=0) | any(alpha<=0) | any(alpha>=1)  % infeasible trial step
0349     ss = '!';
0350     s2 = min(s, s2);  % min ignores NaN, so the first infeasible s sets s2
0351     snew = max((s1+s2)/2,s/2);  % retry with a smaller step
0352   else
0353     %% feasible
0354     isfeas = 1;
0355      dloss = lossDual(alpha)-loss0...
0356             +1/t*(-sum(log(lm1))-sum(log(lm2))...
0357             -sum(log(1+s*delta./alpha0))-sum(log(1-s*delta./(1-alpha0))))...
0358              +s*gnu'*delta;
0359 
0360         
0361     if dloss < dloss_best  % improvement: s becomes the new best, old best a bracket end
0362       ss = '-';
0363       dloss_best = dloss;
0364       if s_best<s
0365         s1 = s_best;
0366       elseif s<s_best
0367         s2 = s_best;
0368       end
0369       
0370       s_best = s;
0371       
0372     else  % no improvement: s becomes a bracket end
0373       ss = '+';
0374       if s_best<s
0375         s2 = s;
0376       elseif s<s_best
0377         s1 = s;
0378       end
0379     end
0380 
0381     if isnan(s2)  % no upper bracket end yet: double the step
0382       snew = s*2;
0383     else
0384       r = 0.5 + rand(1)*0.1-0.05;  % r in [0.45,0.55): randomized midpoint
0385       snew = s1*r+s2*(1-r);
0386     end
0387 
0388     % obj(cc) = loss +1/t*(-log(det(S1))-log(det(S2))-sum(log(alpha))-sum(log(1-alpha)));
0389 
0390   end
0391 
0392   if display
0393     fprintf(' dloss_best=%g (%s)\n',dloss_best, ss);
0394   end
0395   
0396   if (isfeas & s2-s1<0.01)  % stop at a feasible step once the bracket is narrow
0397     break;
0398   end
     % NOTE(review): s1 starts at 0 and is only ever assigned s_best or s, which
     % are finite. So isnan(s1) appears never true and this exit never fires.
     % isnan(s2) may have been intended; confirm before changing.
0399   if (isfeas & isnan(s1) & s<tolX)
0400     break;
0401   end
0402 end
0403 
0404 
function obj = objectiveLocal(alpha0, delta, x, Xf, Y, lambda, t)
% objectiveLocal - evaluate the interior-point objective at the points
%                  alpha = alpha0 + delta*x(:,i) (debugging/plotting aid).
%
% Inputs:
%   alpha0 : base point, [n,1]
%   delta  : directions, [n,k]
%   x      : coefficients, [k, ...]; each column x(:,i) defines one point
%   Xf     : vectorized input matrices, [C^2,n]
%   Y      : labels, [n,1]
%   lambda : regularization constant
%   t      : barrier parameter
% Output:
%   obj    : objective values, shaped like x with its first dimension
%            removed (squeezed). NaN at points that are not strictly feasible
%            (alpha outside (0,1) or lambda*I -/+ A not positive definite).
%
% Example, from the solver state inside lrds_dual:
% [Xl,Yl]=ndgrid(-1:0.1:1);
% xl = [shiftdim(Xl,-1); shiftdim(Yl,-1)];
% r = max(abs(delta))/max(abs(g+Y*nu))
% P = eye(n)-Y*Y'/n;
% objl = objectiveLocal(alpha, [delta, -P*g*r*0.001], xl, Xf, Y, lambda, t);
% figure, surf(Xl, Yl, objl, 'edgecolor','none')

sz = size(x); sz(1)=1;
obj = squeeze(zeros(sz));

C = sqrt(size(Xf,1));
I_C = eye(C);

for i=1:numel(obj)
  alpha = alpha0 + delta*x(:,i);
  A = reshape(Xf*(alpha.*Y), [C,C]); A=(A+A')/2;
  ev1 = eig(lambda*I_C-A);
  ev2 = eig(lambda*I_C+A);

  % Test feasibility before taking logarithms: a log of a non-positive
  % number would be complex (or -Inf), and assigning one complex value
  % would make the whole obj array complex.
  if any(alpha<=0) || any(alpha>=1) || any(ev1<=0) || any(ev2<=0)
    obj(i) = nan;
  else
    obj(i) = lossDual(alpha) + 1/t*(-sum(log(ev1))-sum(log(ev2))...
                                    -sum(log(alpha))-sum(log(1-alpha)));
  end
end
0435

Generated on Sat 26-Apr-2008 15:48:23 by m2html © 2003