随机梯度下降求解svm(MATLAB)

来源:互联网 发布:sftp命令指定端口 编辑:程序博客网 时间:2024/06/05 04:39

本文转载自:http://blog.csdn.net/orangehdc/article/details/38682501

随机梯度下降法(Stochastic Gradient Descent)求解以下的线性SVM模型:


w的梯度为:


传统的梯度下降法需要把所有样本都带入计算,对于一个样本数为n的d维样本,每次迭代求一次梯度,计算复杂度为O(nd) ,当处理的数据量很大而且迭代次数比较多的时候,程序运行时间就会非常慢。

随机梯度下降法每次迭代不再是找到一个全局最优的下降方向,而是用梯度的无偏估计 来代替梯度。每次更新过程为:


由于随机梯度每次迭代采用单个样本来近似全局最优的梯度方向,迭代的步长应适当选小一些以使得随机梯度下降过程尽可能接近于真实的梯度下降法。


下面我用matlab写的一个demo,速度不是很快,跑USPS数据库(二进制格式)csdn下载链接(mat格式),要五分钟,准确率88%左右,效果一般:

[cpp] view plain copy
  1. clear;  
  2. load E:\dataset\USPS\USPS.mat;  
  3. % data format:  
  4. % Xtr n1*dim  
  5. % Xte n2*dim  
  6. % Ytr n1*1   
  7. % Yte n2*1  
  8. % warning: labels must range from 1 to n, n is the number of labels  
  9. % other label values will make mistakes  
  10. u=unique(Ytr);  
  11. Nclass=length(u);  
  12.   
  13. allw=[];allb=[];  
  14. step=0.01;C=0.1;  
  15. param.iterations=1;  
  16. param.lambda=1e-3;  
  17. param.biaScale=1;  
  18. param.t0=100;  
  19.   
  20. tic;  
  21. for classname=1:1:Nclass    
  22.     temp_Ytr=change_label(Ytr,classname);  
  23.     [w,b] = sgd_svm(Xtr,temp_Ytr, param);  
  24.     allw=[allw;w];  
  25.     allb=[allb;b];  
  26.     fprintf('class %d is done \n', classname);  
  27. end  
  28.   
  29. [accuracy predict_label]=my_svmpredict(Xte, Yte, allw, allb);  
  30. fprintf(' accuracy is  %.2f percent.\n' ,  accuracy*100 );  
  31. toc;  


[cpp] view plain copy
  1. function [temp_Ytr] = change_label(Ytr,classname)  
  2. temp_Ytr=Ytr;  
  3. tep2=find(Ytr~=classname);  
  4. tep1=find(Ytr==classname);  
  5. temp_Ytr(tep2)=-1;  
  6. temp_Ytr(tep1)= 1;  

[cpp] view plain copy
  1. function [true_W,b]=sgd_svm(X,Y,param)  
  2. % input:   
  3. % X is n*dim  
  4. % Y is n*1 (label is 1 or 0)  
  5. % output:  
  6. % true_W is dim*1 ,so the score is X*W'+b  
  7. % b      is 1*1 number  
  8. iterations=param.iterations;%10  
  9. lambda=param.lambda;%1e-3  
  10. biaScale=param.biaScale;%0  
  11. t0=param.t0;%100  
  12. t=t0;  
  13.   
  14. w=zeros(1,size(X,2));  
  15. bias=0;  
  16.   
  17. for k=1:1:iterations  
  18.     for i=1:1:size(X,1)  
  19.         t=t+1;  
  20.         alpha = (1.0/(lambda*t));  
  21.         if(Y(i)*(X(i,:)*w'+bias)<1)  
  22.             bias=bias+alpha*Y(i)*biaScale;  
  23.             w=w+alpha*Y(i,1).*X(i,:);  
  24.         end  
  25.     end  
  26. end  
  27. b=bias;  
  28. true_W=w;  


[cpp] view plain copy
  1. function [accuracy predict_label]=my_svmpredict(Xte, Yte, allw, allb)  
  2. % allw is nclass * dim  
  3. % allb is nclass * 1  
  4. % Yte must range from 1 to nclass, other label values will make mistakes  
  5. score = Xte * allw'+repmat(allb',[size(Bte,1),1]);  
  6. [bb  c]=sort(score,2,'descend');  
  7. predict_label=c(:,1);  
  8. temp = predict_label((predict_label-Yte)==0);  
  9. right=size( temp,1 );  
  10. accuracy=right/size(Yte,1);  

0 0