机器学习-线性回归-最小二乘法

来源:互联网 发布:天猫魔盒必装软件 编辑:程序博客网 时间:2024/05/28 15:21

一,背景

1801年,意大利天文学家朱赛普·皮亚齐发现了第一颗小行星谷神星。经过40天的跟踪观测后,由于谷神星运行至太阳背后,使得皮亚齐失去了谷神星的位置。随后全世界的科学家利用皮亚齐的观测数据开始寻找谷神星,但是根据大多数人计算的结果来寻找谷神星都没有结果。时年24岁的高斯也计算了谷神星的轨道。奥地利天文学家海因里希·奥尔伯斯根据高斯计算出来的轨道重新发现了谷神星。

高斯使用的最小二乘法的方法发表于1809年他的著作《天体运动论》中,而法国科学家勒让德于1806年独立发现“最小二乘法”,但因不为时人所知而默默无闻。两人曾为谁最早创立最小二乘法原理发生争执。

1829年,高斯提供了最小二乘法的优化效果强于其他方法的证明,见高斯-马尔可夫定理。(以上文字摘录维基百科)

二,基本介绍

最小二乘法(又称最小平方法)是一种数学优化技术。它通过最小化误差的平方和寻找数据的最佳函数匹配。利用最小二乘法可以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小。最小二乘法还可用于曲线拟合。其他一些优化问题也可通过最小化能量或最大化熵用最小二乘法来表达。

三,最小二乘法(Theleast square method)

“最小二乘法”的核心就是保证所有数据偏差的平方和最小。(“平方”的在古时侯的称谓为“二乘”),假设我们有x轴和y轴,如有数组分别对应x,y{[1,40],[2,43],[3,30],[4,53],[5,36]}
我们需要从这些点中拟合出一条直线,从而预测后续的结果值。我们从中任意取两个点,都可以组成一个线性函数y=ax+b
那么问题来了,a和b在什么时候才是属于最优参数呢?一般有三个标准可以选择:
(1)用“残差和最小”确定直线位置是一个途径。但很快发现计算“残差和”存在相互抵消的问题。
(2)用“残差绝对值和最小”确定直线位置也是一个途径。但绝对值的计算比较麻烦。
(3)最小二乘法的原则是以“残差平方和最小”确定直线位置。用最小二乘法除了计算比较方便外,得到的估计量还具有优良特性。这种方法对异常值非常敏感。
最常用的是最小二乘法:使所有的数据的偏差平方和最小-即采用平方损失函数
公式:这里写图片描述
我们现在要做的就是求使得M最小的a和b。请注意这个方程中,我们已知yi和xi
那其实这个方程就是一个以(a,b)为自变量,M为因变量的二元函数。
回想一下高数中怎么对一元函数就极值。我们用的是导数这个工具。那么在二元函数中,
我们依然用导数。只不过这里的导数有了新的名字“偏导数”。偏导数就是把两个变量中的一个视为常数来求导。
通过对M来求偏导数,我们得到一个方程组
这里写图片描述
根据数学知识我们知道,函数的极值点为偏导为0的点。
这里写图片描述
b则等于:y(平均)-a*x(平均)
下面贴上java实现代码:

package com.hh.test;/** *@Author hehai *@Date 2017/7/20 19:51 * 备注:java实习最小二乘法 **/public class Theleastsquaremethod {    private static double a;    private static double b;    private static int num;    /**     * 训练     *     * @param x     * @param y     */    public static void train(double x[], double y[]) {        num = x.length < y.length ? x.length : y.length;        calCoefficientes(x,y);    }    /**     * a=(NΣxy-ΣxΣy)/(NΣx^2-(Σx)^2)     * b=y(平均)-a*x(平均)     * @param x     * @param y     * @return     */    public static void calCoefficientes (double x[],double y[]){        double xy=0.0,xT=0.0,yT=0.0,xS=0.0;        for(int i=0;i<num;i++){            xy+=x[i]*y[i];            xT+=x[i];            yT+=y[i];            xS+=Math.pow(x[i], 2.0);        }        a= (num*xy-xT*yT)/(num*xS-Math.pow(xT, 2.0));        b=yT/num-a*xT/num;    }    /**     * 预测     *     * @param xValue     * @return     */    public static double predict(double xValue) {        System.out.println("a="+a);        System.out.println("b="+b);        return a * xValue + b;    }    public static void main(String args[]) {        double[] X = new double[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};        double[] Y = new double[]{22, 43, 32, 44, 33, 35, 52, 67, 45, 53};        Theleastsquaremethod.train(X, Y);        System.out.println("预测值:"+Theleastsquaremethod.predict(10.0));    }}
原创粉丝点击