常用算法之:1、最小二乘法(1)

来源:互联网 发布:seo基础理论 编辑:程序博客网 时间:2024/06/05 03:52

深度学习发展到如今的地位,离不开下面这 6 段代码。本文介绍了这些代码的创作者及其完成这些突破性成就的故事背景。每个故事都有简单的代码示例,读者们可以在 FloydHub 和 GitHub 找到相关代码。

最小二乘法

所有的深度学习算法都始于下面这个数学公式(我已将其转成 Python 代码)

  1. # y = mx + b %一个一次线性方程  
  2. # m is slope, b is y-intercept %斜率和截距  
  3. def compute_error_for_line_given_points(b,m,coordinates):  
  4.    totalError = 0  
  5.    for i in range(0,len(coordinates)):  
  6.       x = coordinates[i][0]  
  7.       y = coordinates[i][1]  
  8.       totalError += (y-(m*x+b))**2  
  9.    return totalError/float(len(coordinates))  
  10.   
  11. # example  
  12. compute_error_for_line_given_points(1,2,[[3,6],[6,9],[12,18]])  

最小二乘法在 1805 年由 Adrien-Marie Legendre 首次提出(1805, Legendre),这位巴黎数学家也以测量仪器闻名。他极其痴迷于预测彗星的方位,坚持不懈地寻找一种可以基于彗星方位历史数据计算其轨迹的算法。

他尝试了许多种算法,一遍遍试错,终于找到了一个算法与结果相符。Legendre 的算法是首先预测彗星未来的方位,然后计算误差的平方,最终目的是通过修改预测值以减少误差平方和。而这也正是线性回归的基本思想。

读者可以在 Jupyter notebook 中运行上述代码来加深对这个算法的理解。m 是系数,b 是预测的常数项,coordinates 是彗星的位置。目标是找到合适的 m 和 b 使其误差尽可能小。

Python那些事——这6段代码,解释了什么是编程!

这是深度学习的核心思想:给定输入值和期望的输出值,然后寻找两者之间的相关性。

1、概念简介

根据维基百科的说明:

         最小二乘法(又称最小平方法)是一种数学优化技术。它通过最小化误差的平方和寻找数据的最佳函数匹配。利用最小二乘法可以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小。

         看了之后一头雾水对不对,是的,任何人看着一段不知道在说啥。下面举个例子,就很好懂了。

         针对线性最小二乘法即直线拟合,如下图(来自维基百科)所示:

         

         透过这张图,我想大家一定能理解,我们用最小二乘法来做什么事情,即:

         根据已有的数据(图中的蓝色点),来做出一条最贴近数据发展趋势的直线。通过这条直线,我们可以对未来的数据进行预测,因为基本会落在这条直线附近。

         当然了,最小二乘法不只是直线,还可以是曲线,本文不讨论。

2、求解直线方程

(1)最小二乘法原理:

    在我们研究两个变量(x,y)之间的相互关系时,通常可以得到一系列成对的数据(x1,y1)(x2,y2)(....)(xm,ym);将这些数据描绘在x-y直角坐标系中,若发现这些点在一条直线附近,可以令这条直线方程如(式1-1)。

  yi = a*xi + b                        (式1-1)

  其中:a、b 是任意实数

 

(2)常见拟合曲线:

       直线:    y=a*x+b

      多项式:最小二乘法数据拟合一般次数不易过高.

      双曲线:  y=a/x+b

      指数曲线: y=a*e^b

      matlab中函数:P=polyfit(x,y,n)

     polyval(P,t):返回n次多项式在t处的值

 

我们现在要做的,就是求解直线方程。

假设已知有N个点具有线性相关关系,(x1,y1), (x2,y2),…,(xn,yn)且实数xi不全相等,

设这条直线方程为:  y = m·x + c ,求斜率m和截距c,使得所有点相对于该直线的偏差平方和达到最小。

         解:设实数xi不全相等,所求直线方程为:y= a·x + b

                   要确定a,b,使得函数f(a,b) =∑ni=(yi - (a*xi+b))2最小,

 

其中,a和b的计算公式如下:

 

本文对于推导过程简单讲述,网上都有。

3、线性回归

  线性回归假设数据集中特征与结果存在着线性关系;

  等式:y = mx + c

y为结果,x为特征,m为系数,c为误差在数学中m为梯度c为截距

  这个等式为我们假设的,我们需要找到mc使得mx+c得到的结果与真实的y误差最小,这里使用平方差来衡量估计值与真实值得误差(如果只用差值就可能会存在负数);用于计算真实值与预测值的误差的函数称为:平方损失函数(squard loss function;这里用L表示损失函数,所以有:

  整个数据集上的平均损失为:

  我们要求得最匹配的mc使得L最小;
数学表达式可以表示为:

  最小二乘法用于求目标函数的最优值,它通过最小化误差的平方和寻找匹配项所以又称为:最小平方法;这里将用最小二乘法用于求得线性回归的最优解;

最小二乘法

  为了方便讲清楚最小二乘法推导过程这里使用,数据集有1…N个数据组成,每个数据由、构成,x表示特征,y为结果;这里将线性回归模型定义为:

平均损失函数定义有:


  要求得L的最小,其关于cm的偏导数定为0,所以求偏导数,得出后让导数等于0,并对cm求解便能得到最小的L此时的cm便是最匹配该模型的;

关于c偏导数:

因为求得是关于c的偏导数,因此把L的等式中不包含c的项去掉得:


整理式子把不包含下标n的往累加和外移得到:



c求偏导数得:


关于m的偏导数:

求关于m的偏导数,因此把L等式中不包含项去掉得:


  整理式子把不包含下标n的往累加和外移得到:


m求偏导数得:


令关于c的偏导数等于0,求解:


从上求解得到的值可以看出,上面式子中存在两个平均值,因此该等式也可以改写成:


令关于m的偏导数等于0,求解:
  关于m的偏导数依赖于c,又因为已经求得了关于c偏导数的解,因此把求关于c偏导数的解代数关于m的偏导数式子得:




合并含有m的项化简:


求解:



为了简化式子,再定义出:


C#算法代码如下:

//-------------------------------------------------------------
//
功能 : 最小二乘法直线拟合 y = a·x+ b计算系数a b
//
参数 : x –横坐标数组
//       y --  
纵坐标数组
//       num
是数组包含的元素个数,x[]y[]的元素个数必须相等
//       a,b
都是返回值
//
返回 : 拟合计算成功返回true, 拟合计算失败返回false
//-------------------------------------------------------------
bool leastSquareLinearFit(float x[], float y[], const int num, float &a,float &b)
{
    float sum_x2 = 0.0;
    float sum_y  = 0.0;
    float sum_x  = 0.0;
    float sum_xy = 0.0;

    try

      {
        for (int i = 0; i < num; ++i)

           {
            sum_x2 += x[i]*x[i];
            sum_y  += y[i];
            sum_x  += x[i];
            sum_xy += x[i]*y[i];
        }
    }

     catch (...)

      {
        return false;
    }
    a = (num*sum_xy - sum_x*sum_y)/(num*sum_x2 - sum_x*sum_x);
    b = (sum_x2*sum_y - sum_x*sum_xy)/(num*sum_x2-sum_x*sum_x);

    return true;
}

 数据样本:

x

float temp[96] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.46667, 11.4667, 31.6, 52.7333, 80.3333, 116.333, 156.6, 199.4, 242.2, 283.4, 329.2, 379.333, 431.333, 482.6, 541, 594.4, 643.533, 692.133, 736.267, 772.667, 810.133, 841.867, 868.2, 892.4, 917.667, 939.8, 954.667, 969, 976.8, 983.4, 987.467, 994.933, 1023.67, 875.2, 873.933, 758.8, 678.2, 515.867, 782.533, 908.8, 779.2, 831.4, 645.533, 734.067, 679.533, 610.267, 565.067, 512.467, 462, 405.2, 354.133, 302, 247.8, 191.533, 140, 94.2667, 57.5333, 25.9333, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; // x


y

float tempy[96] = {0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 0.595, 1.785, 2.57833, 3.927, 5.79233, 7.379, 9.48133, 11.1473, 12.4167, 13.6627, 16.193701, 18.248699, 19.042, 19.042, 19.105301, 16.6383, 17.240999, 14.631, 11.8217, 11.663, 12.155, 15.488, 21.859301, 19.32, 19.042, 19.6133, 21.105, 22.9937, 20.827299, 23.858299, 23.0333, 19.2883, 15.6937, 21.5893, 23.802999, 20.518299, 21.5893, 17.907301, 17.971001, 17.574301, 16.781, 15.5513, 12.3773, 10.2747, 8.60867, 6.86333, 5.39567, 3.88767, 2.856, 2.142, 2.142, 0.952, 0.952, 0.952, 0.952, 0.952, 0.952, 0.952, 0.952, 0.952, 0.952, 0.952, 0.952, 0.952, 0.952 }; // y    


计算结果:

a = 0.0215136

b = 0.608488


效果如下:


原创粉丝点击