机器学习

来源:互联网 发布:vb程序设计免费版下载 编辑:程序博客网 时间:2024/04/28 14:22

使用Matlab实现了二分类的SVM,优化技术使用的是Matlab自带优化函数quadprog。

只为检查所学,更为熟悉;不为炫耀。也没有太多时间去使用更多的优化方法。

[plain] view plaincopy
  1. function model = svm0311(data,options)  
  2. %SVM0311  解决2分类的SVM方法,优化使用matlab优化工具箱quadprog函数实现  
  3. %by LiFeiteng     email:lifeiteng0422@gmail.com  
  4. %Reference: stptool  
  5. %           Pattern Recognition and Machine Learning P333 7.32-7.37  
  6.   
  7. % input aruments  
  8. %-------------------------------------------  
  9. tic  
  10.   
  11. data=c2s(data);  
  12. [dim,num_data]=size(data.X);  
  13.   
  14. if nargin < 2, options=[]; else options=c2s(options); end  
  15. if ~isfield(options,'ker'), options.ker = 'linear'; end  
  16. if ~isfield(options,'arg'), options.arg = 1; end  
  17. if ~isfield(options,'C'), options.C = inf; end  
  18. if ~isfield(options,'norm'), options.norm = 1; end  
  19. if ~isfield(options,'mu'), options.mu = 1e-12; end  
  20. if ~isfield(options,'eps'), options.eps = 1e-12; end  
  21.   
  22. X = data.X;  
  23. t = data.y;  
  24. t(t==2) = -1;  
  25.   
  26. % Set up QP task  
  27. %----------------------------  
  28. K = X'*X;  
  29. T = t'*t;% 注意t是横向量  
  30. H = K.*T;  
  31. save('H0311.mat','H')  
  32. H = H + options.mu*eye(size(H));  
  33.   
  34. f = -ones(num_data,1);  
  35. Aeq = t;  
  36. beq = 0;  
  37. lb = zeros(num_data,1);  
  38. ub = options.C*ones(num_data,1);  
  39.   
  40. x0 = zeros(num_data,1);  
  41. qp_options = optimset('Display','off');  
  42. [Alpha,fval,exitflag] = quadprog(H, f,[],[], Aeq, beq, lb, ub, x0, qp_options);  
  43.   
  44. inx_sv = find(Alpha>options.eps);  
  45.   
  46. % compute bias  
  47. %--------------------------  
  48. % take boundary (f(x)=+/-1) support vectors 0 < Alpha < C  
  49. b = 0;  
  50. inx_bound = find( Alpha > options.eps & Alpha < (options.C - options.eps));  
  51. Nm = length(inx_bound);  
  52. for n = 1:Nm  
  53.     tmp = 0;  
  54.     for m = 1:length(inx_sv) %PRML7.37  
  55.         tmp = tmp+Alpha(inx_sv(m))*t(inx_sv(m))*K(inx_bound(n),inx_sv(m));  
  56.     end  
  57.     b = b + t(inx_bound(n))-tmp;  
  58. end  
  59. b = b/Nm;  
  60. model.b = b;     
  61.       
  62. %-----------------------------------------  
  63. w = zeros(dim,1);  
  64. for i = 1:num_data     
  65.     w = w+ Alpha(i)*t(i)*X(:,i);%PRML 7.29  
  66. end  
  67.   
  68. margin = 1/norm(w);  
  69. %-------------------------------------------  
  70. %此处与stprtool保持接口一致  用于画图展示等  
  71. model.Alpha = Alpha( inx_sv );  
  72. model.sv.X = data.X(:,inx_sv );  
  73. model.sv.y = data.y(inx_sv );  
  74. model.sv.inx = inx_sv;  
  75. model.nsv = length( inx_sv );  
  76. model.margin = margin;  
  77. model.exitflag = exitflag;  
  78. model.options = options;  
  79. model.kercnt = num_data*(num_data+1)/2;  
  80. model.trnerr = cerror(data.y,svmclass(data.X, model));  
  81. model.fun = 'svmclass';  
  82.   
  83. model.W = model.sv.X*model.Alpha;  
  84.   
  85. % used CPU time  
  86. model.cputime=toc;  
  87.   
  88. return;  

0 0
原创粉丝点击