基于梯度下降算法求解线性回归
来源:互联网 发布:windows movie make 编辑:程序博客网 时间:2024/05/16 13:58
线性回归(Linear Regression)
梯度下降算法在机器学习方法分类中属于监督学习。利用它可以求解线性回归问题,计算一组二维数据之间的线性关系,假设有一组数据如下下图所示
其中X轴方向表示房屋面积、Y轴表示房屋价格。我们希望根据上述的数据点,拟合出一条直线,能跟对任意给定的房屋面积实现价格预言,这样求解得到直线方程过程就叫线性回归,得到的直线为回归直线,数学公式表示如下:
二:梯度下降 (Gradient Descent)
三:代码实现
数据读入
public List<DataItem> getData(String fileName) { List<DataItem> items = new ArrayList<DataItem>(); File f = new File(fileName); try { if (f.exists()) { BufferedReader br = new BufferedReader(new FileReader(f)); String line = null; while((line = br.readLine()) != null) { String[] data = line.split(","); if(data != null && data.length == 2) { DataItem item = new DataItem(); item.x = Integer.parseInt(data[0]); item.y = Integer.parseInt(data[1]); items.add(item); } } br.close(); } } catch (IOException ioe) { System.err.println(ioe); } return items;}
归一化处理
public void normalization(List<DataItem> items) { float min = 100000; float max = 0; for(DataItem item : items) { min = Math.min(min, item.x); max = Math.max(max, item.x); } float delta = max - min; for(DataItem item : items) { item.x = (item.x - min) / delta; }}
梯度下降
public float[] gradientDescent(List<DataItem> items) { int repetion = 1500; float learningRate = 0.1f; float[] theta = new float[2]; Arrays.fill(theta, 0); float[] hmatrix = new float[items.size()]; Arrays.fill(hmatrix, 0); int k=0; float s1 = 1.0f / items.size(); float sum1=0, sum2=0; for(int i=0; i<repetion; i++) { for(k=0; k<items.size(); k++ ) { hmatrix[k] = ((theta[0] + theta[1]*items.get(k).x) - items.get(k).y); } for(k=0; k<items.size(); k++ ) { sum1 += hmatrix[k]; sum2 += hmatrix[k]*items.get(k).x; } sum1 = learningRate*s1*sum1; sum2 = learningRate*s1*sum2; // 更新 参数theta theta[0] = theta[0] - sum1; theta[1] = theta[1] - sum2; } return theta;}
价格预言
public float predict(float input, float[] theta) { float result = theta[0] + theta[1]*input; return result;}
线性回归图
public void drawPlot(List<DataItem> series1, List<DataItem> series2, float[] theta) { int w = 500; int h = 500; BufferedImage plot = new BufferedImage(w, h, BufferedImage.TYPE_INT_ARGB); Graphics2D g2d = plot.createGraphics(); g2d.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON); g2d.setPaint(Color.WHITE); g2d.fillRect(0, 0, w, h); g2d.setPaint(Color.BLACK); int margin = 50; g2d.drawLine(margin, 0, margin, h); g2d.drawLine(0, h-margin, w, h-margin); float minx=Float.MAX_VALUE, maxx=Float.MIN_VALUE; float miny=Float.MAX_VALUE, maxy=Float.MIN_VALUE; for(DataItem item : series1) { minx = Math.min(item.x, minx); maxx = Math.max(maxx, item.x); miny = Math.min(item.y, miny); maxy = Math.max(item.y, maxy); } for(DataItem item : series2) { minx = Math.min(item.x, minx); maxx = Math.max(maxx, item.x); miny = Math.min(item.y, miny); maxy = Math.max(item.y, maxy); } // draw X, Y Title and Aixes g2d.setPaint(Color.BLACK); g2d.drawString("价格(万)", 0, h/2); g2d.drawString("面积(平方米)", w/2, h-20); // draw labels and legend g2d.setPaint(Color.BLUE); float xdelta = maxx - minx; float ydelta = maxy - miny; float xstep = xdelta / 10.0f; float ystep = ydelta / 10.0f; int dx = (w - 2*margin) / 11; int dy = (h - 2*margin) / 11; // draw labels for(int i=1; i<11; i++) { g2d.drawLine(margin+i*dx, h-margin, margin+i*dx, h-margin-10); g2d.drawLine(margin, h-margin-dy*i, margin+10, h-margin-dy*i); int xv = (int)(minx + (i-1)*xstep); float yv = (int)((miny + (i-1)*ystep)/10000.0f); g2d.drawString(""+xv, margin+i*dx, h-margin+15); g2d.drawString(""+yv, margin-25, h-margin-dy*i); } // draw point g2d.setPaint(Color.BLUE); for(DataItem item : series1) { float xs = (item.x - minx) / xstep + 1; float ys = (item.y - miny) / ystep + 1; g2d.fillOval((int)(xs*dx+margin-3), (int)(h-margin-ys*dy-3), 7,7); } g2d.fillRect(100, 20, 20, 10); g2d.drawString("训练数据", 130, 30); // draw regression line g2d.setPaint(Color.RED); for(int i=0; i<series2.size()-1; i++) { float x1 = (series2.get(i).x - minx) / xstep + 1; float y1 = (series2.get(i).y - miny) / ystep + 1; float x2 = (series2.get(i+1).x - minx) / xstep + 1; float y2 = (series2.get(i+1).y - miny) / ystep + 1; g2d.drawLine((int)(x1*dx+margin-3), (int)(h-margin-y1*dy-3), (int)(x2*dx+margin-3), (int)(h-margin-y2*dy-3)); } g2d.fillRect(100, 50, 20, 10); g2d.drawString("线性回归", 130, 60); g2d.dispose(); saveImage(plot);}
四:总结
本文通过最简单的示例,演示了利用梯度下降算法实现线性回归分析,使用更新收敛的算法常被称为LMS(Least Mean Square)又叫Widrow-Hoff学习规则,此外梯度下降算法还可以进一步区分为增量梯度下降算法与批量梯度下降算法,这两种梯度下降方法在基于神经网络的机器学习中经常会被提及,对此感兴趣的可以自己进一步探索与研究。
只分享干货,不止于代码
阅读全文
5 0
- 基于梯度下降算法求解线性回归
- 梯度下降法求解线性回归问题
- 基于梯度下降法实现线性回归算法
- 一元线性回归与梯度下降算法
- 线性回归和梯度下降算法
- 线性回归与梯度下降算法
- 线性回归与梯度下降算法
- 线性回归与梯度下降算法(1)
- 线性回归、梯度下降算法与 tensorflow
- 线性回归及梯度下降算法详解
- 线性回归与梯度下降算法
- 线性回归&梯度下降
- 梯度下降法求解线性回归之python实现
- 梯度下降法求解线性回归之matlab实现
- 梯度下降求解逻辑回归
- 基于梯度下降法的线性回归模型
- 基于matlab的梯度下降法实现线性回归
- [笔记]线性回归&梯度下降
- select设置只读
- 【java】:java实体类
- jsp的内置对象
- Redis的安装
- @ModelAttribute注解的作用
- 基于梯度下降算法求解线性回归
- Android系统 boot.img 结构
- DataFrame如何根据一列来计算另一列出现的次数
- HDU-1069-Monkey and Banana-DP
- Redis基础
- 随机优惠券发放 金额越大 概率越小金额越小概率越大算法
- 挑战程序竞赛系列(12):2.5最小生成树
- unix://localhost:80: Permission denied 问题解决
- C语言实现单链表面试题--基础篇