【机器学习】线性回归的梯度下降法

来源:互联网 发布:图文软件有哪些 编辑:程序博客网 时间:2024/04/28 12:43

摘要:这是我学习斯坦福大学《机器学习》课程的第一个算法。该算法属于回归模型中最简单的模型——线性回归,使用梯度下降法达到最优拟合。

课程中对机器学习的定义是:Field of study that gives computers the ability to learn without being explicitly programmed.字面意思为不通过显式编程赋予计算机学习的能力。课程开始阐明了监督学习(supervise learning)包含两种问题:回归(regression)和分类(classification)。两种问题都是输入X到输出Y的映射。关于“监督”的含义,简单地说,就是输入和输出数据都已经给出。这类问题的任务是,我们需要从这些输入输出数据中提取一个模式,使得我们有一个新的输入x0时,可以从这个模式中预测输出y0.

一、问题引入

课程使用房屋价格和房屋面积的关系的例子作为引入。
这里写图片描述
房价2万左右吧,类似广州的房价(这个数据只是我自己编的)。根据数据表我们可以作出散点图如下:
这里写图片描述
假如我们看中的房子,面积不在这个列表中,如何估计这间房子的价格呢?我们可以作一个平滑曲线尽可能去拟合所有散点。本例阐明线性回归,于是可用一条直线去拟合:
这里写图片描述
当我们有新的输入x0时,便有预测输出y0。这是高中生都知道的回归分析。
我们可以把具体的问题抽象出来:
房屋销售表:训练集(training set),是回归分析中的输入数据,用x表示;
房屋销售价格:输出数据,用y表示;
拟合的函数:y=h(x),输出模型。
本例输入数据是二维的,而输出是一维模型。其实线性回归可以处理多维的数据。由此可知,回归分析的输出必定是一个连续的模型。

二、问题分析
房屋面积只是影响房屋价格的因素之一,还有其他因素(诸如,地段、环境等)可独立地同时影响房价,把这些因素都提取出来,称为特征变量x,该变量是一个向量,即x=[x0, x1, x2, … , xn].
我们可以假设一个估计函数:
这里写图片描述
其中θ为特征参数,该参数的大小决定了特征变量xi对估计的影响有多大。用向量形式表示为:
这里写图片描述
为了达到最好的拟合效果,必须有一个函数去衡量输出模型的拟合效果。一般把这个函数定义为损失函数(loss function),以下我们称这个函数为J函数:
这里写图片描述
这个损失函数是对估计值与真实值的差的平方进行估计,前面的1/2是为了消去求导时的系数。之后我们要做的就是使拟合的损失最小,即最小化J函数时取得的θ值便是我们理想中的参数。

三、梯度下降法
梯度下降法的主要流程是:
1. 首先对θ赋值,这个值是随机的,也可以赋值为一个零向量;
2. 改变θ的值,使得J函数往下降最快的方向减少,直至函数收敛。
其中梯度方向由J(θ)对θ的偏导数确定。计算过程省略,计算结果为:
这里写图片描述
此算法不断迭代更新θ的值,直到函数收敛为止。其中α是学习因子,其实就是表征下降的速度,α越大,函数下降越快,收敛越快。一般是一个很小数值。迭代其实就是循环,我们不可能永远循环下去,需要一个循环结束条件。循环终止条件有很多种,我们可以把迭代差分小于一个ε值作为结束条件,即:
这里写图片描述

四、Matlab代码实现

% load dataheart_scale = load('../heart_scale');X = heart_scale.heart_scale_inst;Y = heart_scale.heart_scale_label;epsilon = 0.0005;alpha= 0.0001;theta_old = zeros(size(X,2),1);i = 1;figure(1);while 1    minJ_theta(i) = 1/2 * (norm(X * theta_old - Y))^2;  % J函数    theta_new = theta_old - alpha * (X'* X * theta_old - X'* Y);    % 迭代求θ    fprintf('The %dth iteration, minJ_theta = %f, \n', i, minJ_theta(i));    if norm(theta_new - theta_old) < epsilon     % 迭代终止条件        theta_best = theta_new;        break;    end    theta_old = theta_new;    i = i + 1;endplot(minJ_theta);fprintf('The best theta is %f\n', theta_best);

其中heart_scale来源于libsvm数据包,heart_scale_inst包括270个13维的样本,heart_scale_label都是+-1.

运行效果如下:
这里写图片描述
这里写图片描述
有时间还想用Python去实现一把,等待更新吧!

1 0
原创粉丝点击