梯度下降法实现(step-by-step)

来源:互联网 发布:ubuntu 查看samba用户 编辑:程序博客网 时间:2024/05/24 06:37

机器学习入门:线性回归及梯度下降对原理做了详细的说明

本文主要记录实现时的代码

1.实验数据

采用随机的方式,生成(X,Y)数据对

k=10;x=zeros(k,1);y=zeros(k,1);for i=1:kx(i,1)=30+i*5;y(i,1)=(8+rand)*x(i,1);endplot(x,y,'.r');

2.theta0,theta1,J函数

当theta0=0时

x=[1 2 3];y=[1 2 3];k=5;c1=zeros(1,k);Jf=zeros(1,k);for i=1:kc1(1,i)=(i-1)*0.5;endfor i=1:kfor j=1:3Jf(1,i)=Jf(1,i)+(c1(1,i)*x(1,j)-y(1,j))^2;endJf(1,i)=Jf(1,i)/6;endplot(c1,Jf);function value=Jfunction(c0,c1,x,y)k=size(x,1);value=0;for j=1:kvalue=value+(c1*x(1,j)+c0-y(1,j))^2;endvalue=value/(2*k);

当theta0不恒为0时

clear;clc;x=[1 2 3];y=[1 2 3]; c0=linspace(-10, 10, 100);    c1=linspace(-1, 4, 100); Jf=zeros(length(c0),length(c1));  for i=1:length(c0)      for j=1:length(c1)    Jf(i,j)=Jfunction(c0(1,i),c1(1,j),x,y);      end  end  surf(c0,c1,Jf);  contour(c0,c1,Jf,logspace(-2, 3, 20));


3.梯度下降法和随机梯度下降法

clear;clc;close all;    k=10;      % a=0.000001;      a=0.001;      x=zeros(k,1);      y=zeros(k,1);      for i=1:k      x(i,1)=30+i*5;      y(i,1)=(8+rand)*x(i,1);      end      plot(x,y,'.r');      [c0,c1,count]=CalGradient(x,y,a);      [c2,c3,count2]=Gradient_descent_rand(x,y,a);        x=35:5:80;      y1=c1*x+c0;      y2=c3*x+c2;      hold on      plot(x,y1);      plot(x,y2,'r');    count      c0      c1      count2    c2      c3  
function [theta0,theta1,count]=CalGradient(X,Y,a);  theta0=0;  theta1=0;  t0=0;  t1=0; count=0;while(1)      for i=1:size(X,2)          t0=t0+(theta0+theta1*X(i,1)-Y(i,1))*1;          t1=t1+(theta0+theta1*X(i,1)-Y(i,1))*X(i,1);      end      old_theta0=theta0;      old_theta1=theta1;      theta0=theta0-a*t0;    theta1=theta1-a*t1;      t0=0;      t1=0;      if(sqrt((old_theta0-theta0)^2+(old_theta1-theta1)^2)<0.000001)         break;      end      count=count+1;end  
    function [theta0,theta1,count]=Gradient_descent_rand(X,Y,a);        theta0=0;        theta1=0;      count=0;      flag=true;    while(flag)      for i=1:size(X,2)         old_theta0=theta0;          old_theta1=theta1;          theta0=theta0-a*(theta0+theta1*X(i,1)-Y(i,1))*1;          theta1=theta1-a*(theta0+theta1*X(i,1)-Y(i,1))*X(i,1);         count=count+1;               if(sqrt((old_theta0-theta0)^2+(old_theta1-theta1)^2)<0.000001)              flag=false;             break;             end     end      end
a:学习率,太小的话,收敛较慢,太大可能无法收敛

count =

    12


c0 =

    0.2371


c1 =

    8.2969


count2 =

    12


c2 =

    0.2373


c3 =

    8.2969


0 0
原创粉丝点击