随机梯度下降和批量梯度下降的简单代码实现

来源:互联网 发布:scd数据库期刊 编辑:程序博客网 时间:2024/05/23 15:42

        最近刚刚开始看斯坦福的机器学习公开课,第一堂课讲到了梯度下降,然后就简单实现了一下。本人学渣一枚,如有错误,敬请指出。

     

/** * Created by Administrator on 2016/4/16 0016. */public class GradientDescent {    private static double[][] data = {            {3.8, 192.0314202},            {3.5, 194.1168421},            {4, 195.1114837},            {4.4, 197.7640977},            {4.1, 196.8811122},            {4.6, 202.9643527},            {3.6, 191.245283},            {3.2, 189.2631579},            {3.4, 189.9758454},            {3, 187.6717949},            {3.9, 193.5243902},            {3.1, 189.2704403},            {2.2, 177.248366},            {3.7, 189.296875},            {3.3, 189.5043478},            {4.2, 199.6857143},    };    //根据excel得到的回归方程:y = 9.3581x + 158.3,数据来自日常的一个项目    public static void main(String[] args) {        stochastic(data);        batch(data);    }    /*    * 当rate = 0.01时    * 循环2000左右的时候值就不变化了    * parameter is 157.90981024717982 9.482991891267803    * error is 47.897064097242335    *    * 当rate = 0.001时    * 循环30000,最后结果几乎不变    * parameter is 158.25947581125462 9.36980772136795    * error is 47.7491293901012    * */    private static void stochastic(double[][] data) {        double[] p = {0, 0};//初始化参数为0        double rate = 0.001;        for (int i = 0; i < 30000; i++) {            for (double[] aData : data) {                double h = 0, err;                h += p[0] + p[1] * aData[0];                err = aData[1] - h;                //根据每一条数据更新参数                p[0] += rate * err * 1;                p[1] += rate * err * aData[0];            }        }        System.out.println("parameter is " + p[0] + " " + p[1]);        double error = 0;        for (double[] aData : data) {            error += Math.pow(aData[1] - (p[0] + p[1] * aData[0]), 2);        }        System.out.println("error is " + error);    }    /*    * rate = 0.001, 循环次数等于30000时,所计算的结果和excel计算的几乎完全一致    *    * parameter is 158.299201608832 9.358074090590318    * error is 47.74825830393555    *    * 批量梯度下经确实更加准确    * */    private static void batch(double[][] data) {        double[] p = {0, 0};        double rate = 0.001;        for (int i=0;i<50000;i++){            double err1 = 0;            double err2 = 0;            for (double[] aData:data){                double h=0;                h=p[0]+p[1]*aData[0];                err1 += aData[1] - h;                err2 += (aData[1]-h)*aData[0];            }            //遍历整个数据集之后再更新参数            p[0] += rate*err1;            p[1] += rate*err2;        }        System.out.println("parameter is " + p[0] + " " + p[1]);        double error = 0;        for (double[] aData : data) {            error += Math.pow(aData[1] - (p[0] + p[1] * aData[0]), 2);        }        System.out.println("error is " + error);    }}


0 0
原创粉丝点击