机器学习入门 | 使用梯度下降(Gradient Descent)实现线性回归(Liner Regression)

来源:互联网 发布:软件质量保证承诺书 编辑:程序博客网 时间:2024/06/05 15:43

线性回归算法作为机器学习的入门算法就像高级语言的hello wold,线性回归算法是一种有监督的算法,有监督和无监督简单来说就是:

                        1、有监督即表示对于一个训练样本,我们知道输入和输出,并且可以根据输入输出来训练模型,比如:我有一个抽样统计的全国大学生身高和体重的数据,那么根据身高来预测体重的模型就属于有监督模型,因为训练数据中有身高和体重的对应关系,根据这个对应关系我们在训练模型的过程能直接给提供反馈,让模型知道他预测的准不准。常见的有监督模型:LR、KNN、SVM等

                       2、无监督,我们只有输入的数据没有输出数据,换句话说,对于结果是什么样,我们也不知道。典型的模型就是多属性的聚类分析,在聚类之前我们是不知道数据集中有多少类,这时候就需要用无监督的模型进行预测。常用的无监督模型:K-mearn、PCA等

          线性回归的函数表达式比较简单:

                                   

        其中θ表示的是回归函数参数,我们可以通过样本数据训练模型来得到这个参数。

        X1,X2表示的是特征值,在现实中表示的是有多少特征与对预测结果有关,比如对于预测体重的回归模型,年龄,身高,性别对结果有影响,那么这三个属性就是特征值。 

        那么对于给定的训练数据我们如何确定参数θ 呢?简单来说,由于我们知道训练数据的输入和输出,所以对于参数θ是否选择合适,我们可以把拟合的结果和训练数据的结果相比较就知道参数θ是否选择恰当。这里我们使用代价函数(cost function)也有叫损失函数(loss function)或者 误差函数(error  function)来进行结果的量化:

                                          

                                  

       我们的目标就是通过确定参数θ 来使这个J(θ)最小化。另外,这个求和符号前边的1/2也可以替换成其他值,比如1/4、1/8等,使用1/2的好处使对J(θ)求偏导数可以消除偏导前面的2.

         使J(θ)最小化的方法有很多,常见的有:最小二乘法、梯度下降法、牛顿法等。本次我们使用梯度下降法来求解参数θ。

         先看原理,梯度下降法的原理是对J(θ)求偏导数,而导数就是函数某点的处的斜率,通过对函数求导我们就能知道函数曲线下降的方向,然后通过不断的对参数进行迭代最终会得到J(θ)的最小值,同时也就求出参数θ。

                                               

               根据梯度算法的公式,我们就可以用样本进行训练,来获得参数θ。

          使用梯度算法实现线性回归模型的SQL脚本实现:

         预测背景:通过采样获得大学生的体重数据,特征值为:身高、体重、年龄,回归模型为:

         g(x)=θ0+θ1X1+θ2X2+θ3X3

          其中,X1,X2,X3表示的是特征 身高、体重、年龄,θ0~3表示的是参数值,也就是我们要确定的值。

         梯度法更新参数值:

                                             

        初始化:

                1、令θ0、θ1、θ2、θ3等于0

                2、令a=0.001,a是控制梯度下降的步长,a的取值直接影响梯度算法是否可以收敛,太小收敛慢,取值太大会造成摆动效应,似使算法无法收敛,甚至会导致溢出。

               3、M表示样本的数量

       SQL代码如下:

 declare @c0 float,  @c1 float, @c2 float,@c3 float,@a float     set @c0=0         set @c1=0         set @c2=0         set @c3=0         set @a=0.001             while 1=1    begin        set @c0=@c0-@a(select sum(@c0-@c1*X1+@c2*X2+@c3*X3-Y)/30.0 from F_weight    set @c1=@c1-@a(select sum(@c0-@c1*X1+@c2*X2+@c3*X3-Y)*X1/30.0 from F_weight    set @c2=@c2-@a(select sum(@c0-@c1*X1+@c2*X2+@c3*X3-Y)*X2/30.0 from F_weight    set @c3=@c3-@a(select sum(@c0-@c1*X1+@c2*X2+@c3*X3-Y)*X3/30.0 from F_weight        if abs(@c0-@c1*X1+@c2*X2+@c3*X3-Y)<0.001 OR       abs(@c0-@c1*X1+@c2*X2+@c3*X3-Y)*X1<0.001 OR       abs(@c0-@c1*X1+@c2*X2+@c3*X3-Y)*X2<0.001 OR       abs(@c0-@c1*X1+@c2*X2+@c3*X3-Y)*X3<0.001          break;    end     

    


阅读全文
0 0