线性回归的推导与java代码
来源:互联网 发布:淘宝店铺怎么做淘宝客 编辑:程序博客网 时间:2024/06/03 13:56
1.1线性回归的数学表达
1.2线性回归的java代码实现
java实现一元线性回归:
public class DataPoint { public float x; public float y; public DataPoint(float x,float y){ //DataPoint类的构造函数 this.x = x; this.y = y; }}//RegressionLine类,用于处理一元线性回归问题import java.math.BigDecimal;import java.util.ArrayList;public class RegressionLine { private float sumX = 0;//训练集x的和 private float sumY = 0;//训练集y的和 private float sumXX = 0;//x*x的和 private float sumYY = 0;//y*y的和 private float sumXY = 0;//x*y的和 private float sumDeltaY;//y与yi的差 private float sumDeltaY2; // sumDeltaY的平方和 //误差 private float sse;//残差平方和 private float sst;//总平方和 private float E; private float[] xy; private ArrayList<String> listX;//x的链表 private ArrayList<String> listY;//y的链表 private double XMin,XMax,YMin,YMax; private float a0;//线性系数a0private float a1;//线性系数a1 private int pn; //训练集数据个数 private boolean coefsValid;//类RegressionLine的构造函数 public RegressionLine(){ XMax = 0; YMax = 0; pn = 0; xy = new float[2]; listX = new ArrayList<>(); listY = new ArrayList<>(); } //类RegressionLine的有参构造函数 public RegressionLine(DataPoint data[]){ pn = 0; xy = new float[2]; listX = new ArrayList(); listY = new ArrayList(); for(int i = 0;i<data.length;++i){ addDatapoint(data[i]);//添加数据集的方法addDatapoint } } public int getDataPointCount(){ return pn; } public float getA0(){ validateCoefficients(); return a0; } public float getA1(){ validateCoefficients(); return a1; } public double getSumX(){ return sumX; } public double getSumY() { return sumY; } public double getSumXX() { return sumXX; } public double getSumYY() { return sumYY; } public double getSumXY() { return sumXY; } public double getXMin() { return XMin; } public double getXMax() { return XMax; } public double getYMax() { return YMax; } public double getYMin() { return YMin; } //添加训练集数据的方法 public void addDatapoint(DataPoint dataPoint){ sumX += dataPoint.x; sumY += dataPoint.y; sumXX += dataPoint.x*dataPoint.x; sumYY += dataPoint.y*dataPoint.y; sumXY += dataPoint.x*dataPoint.y; if(dataPoint.x > XMax){ XMax = dataPoint.x; } if (dataPoint.y > YMax){ YMax = dataPoint.y; } xy[0] = dataPoint.x ;//? xy[1] = dataPoint.y ;//? if(dataPoint.x !=0 && dataPoint.y != 0){ System.out.print("("+xy[0]+","); System.out.println(xy[1]+")"); try{ listX.add(pn,String.valueOf(xy[0])); listY.add(pn,String.valueOf(xy[1])); }catch (Exception e){ e.printStackTrace(); } } ++pn; coefsValid = false; } //计算预测值y的方法 public float at(float x){ if(pn < 2) return Float.NaN; validateCoefficients(); return a0 + a1 * x; } //重置此类的方法 public void reset(){ pn = 0; sumX = sumY = sumXX = sumXY = 0; coefsValid = false; } //计算系数a0,a1的方法 private void validateCoefficients(){ if (coefsValid) return; if (pn >= 2){ float xBar = (float)sumX/pn; float yBar = (float)sumY/pn; a1 = (float)((pn*sumXY - sumX*sumY)/(pn *sumXX - sumX*sumX)); a0 = (yBar - a1*xBar); } else { a0 = a1 = Float.NaN; } coefsValid = true; } //计算判定系数R^2的方法 public double getR(){ for (int i = 0;i < pn;i++){ float Yi = Float.parseFloat(listY.get(i).toString()); float Y = at(Float.parseFloat( listX.get(i).toString())); float deltaY = Yi - Y; float deltaY2 = deltaY*deltaY; sumDeltaY2 += deltaY2; float deltaY1 = (Yi - (float) (sumY/pn))*(Yi - (float) (sumY/pn)) ; sst += deltaY1; } //sst = sumYY - (sumY*sumY)/pn; E = 1 - sumDeltaY2/sst; return round(E,4); } //返回经处理过的判定系数的方法 public double round(double v,int scale){ BigDecimal b = new BigDecimal(Double.toString(v)); BigDecimal one = new BigDecimal("1"); return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).floatValue(); }}//测试类import java.util.Scanner;public class LinearRegression { private static final int MAX_POINTS = 10;//定义最大的训练集数据个数 private double E; public static void main(String args[]){ //测试主方法 DataPoint[] data = new DataPoint[MAX_POINT]; //创建数据集对象数组data[]//创建线性回归类对象line,并且初始化类 RegressionLine line = new RegressionLine(constructDates(data));//调用printSums方法打印Sum变量 printSums(line);//调用printLine方法并打印线性方程 printLine(line); } //构建数据方法 private static DataPoint[] constructDates(DataPoint date[]){ Scanner sc = new Scanner(System.in); float x,y; for(int i = 0;i<3;i++){ System.out.println("请输入第"+(i+1)+"个x的值:"); x = sc.nextFloat(); System.out.println("请输入第"+(i+1)+"个y的值:"); y = sc.nextFloat(); date[i] = new DataPoint(x,y); } return date; } //打印Sum数据方法 private static void printSums(RegressionLine line){ System.out.println("\n数据点个数 n = "+ line.getDataPointCount()); System.out.println("\nSumX = "+line.getSumX()); System.out.println("SumY = "+line.getSumY()); System.out.println("SumXX = "+line.getSumXX()); System.out.println("SumXY = "+line.getSumXY()); System.out.println("SumYY = "+line.getSumYY()); } //打印回归方程方法 private static void printLine(RegressionLine line){ System.out.println("\n回归线公式:y = "+line.getA1() +"x + " + line.getA0()); //System.out.println("Hello World!"); System.out.println("误差: R^2 = " + line.getR()); }}
测试结果:
输入测试数据如下
程序运行结果为:
阅读全文
0 0
- 线性回归的推导与java代码
- 【西瓜书】线性回归在回归,二分类,多分类问题上的应用与推导
- logsit回归代码的推导
- logsit回归代码的推导
- 线性回归---公式推导
- 逻辑回归的数学推导及java代码实现
- 线性回归的代码实现
- 1.线性回归的推导--梯度下降法
- 回归 ---- 线性回归,多元回归与逻辑回归的关系
- 线性回归与logistic回归的思路
- 线性回归的java实现
- 线性回归与岭回归python代码实现
- 线性回归中的最小二乘法,L1,L2推导
- 二元线性回归最小二乘法公式推导
- 线性回归推导过程和实例
- 线性回归代码matlab
- R语言--线性回归(2)回归模型推导
- 线性回归与逻辑回归
- 20171211Link
- 一个故事让你彻底理解 Https
- delphi 权限控制(delphi TActionList方案)
- Lucene索引原理
- 判断结构
- 线性回归的推导与java代码
- http请求响应状态码
- Linux下的静态库和共享库的创建和使用
- 代理人模式
- StringUtils类中isEmpty与isBlank的区别(空格的体现)
- MySQL--SET
- 2017哈理工 低年级组院赛初赛 G-做游戏 【水题】
- 600台自动售货机的管理系统是这样 | 新零售「良品铺子」
- 使用idea commit代码时遇到的detached head 问题的解决