线性回归的推导与 Java 代码

来源:互联网 发布:淘宝店铺怎么做淘宝客 编辑:程序博客网 时间:2024/06/03 13:56

1.1线性回归的数学表达

这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述
这里写图片描述

1.2 线性回归的 Java 代码实现

Java 实现一元线性回归:

// Single-file listing: save as RegressionLine.java, compile with
// `javac RegressionLine.java`, and start the interactive demo with
// `java LinearRegression`.
//
// DataPoint and LinearRegression are declared without `public` only because
// Java allows a single public top-level class per source file. The listing has
// no package declaration, so every caller lives in the default package and can
// still use all three classes. Move a class into its own file if you want it
// to be public.

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Scanner;

/** One (x, y) training sample. */
class DataPoint {
    public float x;
    public float y;

    /**
     * Creates a sample.
     *
     * @param x the independent variable
     * @param y the observed dependent variable
     */
    public DataPoint(float x, float y) {
        this.x = x;
        this.y = y;
    }
}

/**
 * Ordinary least-squares fit of the line {@code y = a0 + a1 * x} to a set of
 * {@link DataPoint}s (simple, one-variable linear regression).
 *
 * <p>Running sums are kept in {@code double} so that the subtraction
 * {@code n * sumXX - sumX * sumX} loses as little precision as possible.
 * The coefficients are computed lazily and cached until the data changes.
 * This class is not thread-safe.
 */
public class RegressionLine {
    private double sumX;   // sum of x
    private double sumY;   // sum of y
    private double sumXX;  // sum of x * x
    private double sumYY;  // sum of y * y
    private double sumXY;  // sum of x * y

    // Extremes of the data seen so far; all 0 while no point has been added.
    private double xMin;
    private double xMax;
    private double yMin;
    private double yMax;

    // Private copies of every sample, needed to compute residuals in getR().
    private final List<DataPoint> points = new ArrayList<>();

    private float a0;            // intercept
    private float a1;            // slope
    private boolean coefsValid;  // true while a0/a1 match the current data

    /** Creates an empty regression with no data points. */
    public RegressionLine() {
    }

    /**
     * Creates a regression from an array of samples.
     *
     * <p>{@code null} entries are skipped, so a partially filled array (as
     * produced by the demo program) can be passed directly.
     *
     * @param data the samples to add; must not be {@code null}
     * @throws NullPointerException if {@code data} itself is {@code null}
     */
    public RegressionLine(DataPoint[] data) {
        Objects.requireNonNull(data, "data");
        for (DataPoint point : data) {
            if (point != null) {
                addDatapoint(point);
            }
        }
    }

    /** Returns the number of data points added so far. */
    public int getDataPointCount() {
        return points.size();
    }

    /** Returns the intercept a0, or {@code NaN} if the line is undefined. */
    public float getA0() {
        validateCoefficients();
        return a0;
    }

    /** Returns the slope a1, or {@code NaN} if the line is undefined. */
    public float getA1() {
        validateCoefficients();
        return a1;
    }

    /** Returns the sum of all x values. */
    public double getSumX() {
        return sumX;
    }

    /** Returns the sum of all y values. */
    public double getSumY() {
        return sumY;
    }

    /** Returns the sum of x squared. */
    public double getSumXX() {
        return sumXX;
    }

    /** Returns the sum of y squared. */
    public double getSumYY() {
        return sumYY;
    }

    /** Returns the sum of x * y. */
    public double getSumXY() {
        return sumXY;
    }

    /** Returns the smallest x added so far (0 if there is no data). */
    public double getXMin() {
        return xMin;
    }

    /** Returns the largest x added so far (0 if there is no data). */
    public double getXMax() {
        return xMax;
    }

    /** Returns the largest y added so far (0 if there is no data). */
    public double getYMax() {
        return yMax;
    }

    /** Returns the smallest y added so far (0 if there is no data). */
    public double getYMin() {
        return yMin;
    }

    /**
     * Adds one sample, updates the running sums and extremes, and echoes the
     * point to standard output as {@code (x,y)}.
     *
     * <p>Every point is stored, including points with a zero coordinate.
     * A private copy is kept, so later changes to {@code dataPoint} do not
     * affect the fit.
     *
     * @param dataPoint the sample to add
     * @throws NullPointerException if {@code dataPoint} is {@code null}
     */
    public void addDatapoint(DataPoint dataPoint) {
        Objects.requireNonNull(dataPoint, "dataPoint");
        final float x = dataPoint.x;
        final float y = dataPoint.y;

        sumX += x;
        sumY += y;
        sumXX += (double) x * x;
        sumYY += (double) y * y;
        sumXY += (double) x * y;

        if (points.isEmpty()) {
            // The first point defines both extremes, so negative-only data
            // is handled correctly.
            xMin = x;
            xMax = x;
            yMin = y;
            yMax = y;
        } else {
            xMin = Math.min(xMin, x);
            xMax = Math.max(xMax, x);
            yMin = Math.min(yMin, y);
            yMax = Math.max(yMax, y);
        }

        System.out.println("(" + x + "," + y + ")");
        points.add(new DataPoint(x, y));
        coefsValid = false;
    }

    /**
     * Evaluates the fitted line at {@code x}.
     *
     * @param x the independent variable
     * @return the predicted y, or {@code NaN} if fewer than two points have
     *     been added or the line is undefined
     */
    public float at(float x) {
        if (points.size() < 2) {
            return Float.NaN;
        }
        validateCoefficients();
        return a0 + a1 * x;
    }

    /** Discards all data and all derived state, returning to the empty state. */
    public void reset() {
        points.clear();
        sumX = 0;
        sumY = 0;
        sumXX = 0;
        sumYY = 0;
        sumXY = 0;
        xMin = 0;
        xMax = 0;
        yMin = 0;
        yMax = 0;
        coefsValid = false;
    }

    /**
     * Recomputes a0 and a1 if the data changed since the last computation.
     *
     * <p>Both coefficients become {@code NaN} when there are fewer than two
     * points or when all x values are identical (a vertical line has no
     * slope).
     */
    private void validateCoefficients() {
        if (coefsValid) {
            return;
        }
        a0 = Float.NaN;
        a1 = Float.NaN;
        final int n = points.size();
        // Comparing the extremes detects "all x equal" exactly, whereas the
        // computed denominator could be a tiny non-zero rounding residue.
        if (n >= 2 && xMin != xMax) {
            final double denominator = n * sumXX - sumX * sumX;
            final double slope = (n * sumXY - sumX * sumY) / denominator;
            a1 = (float) slope;
            a0 = (float) (sumY / n - slope * (sumX / n));
        }
        coefsValid = true;
    }

    /**
     * Computes the coefficient of determination R^2 = 1 - SSE / SST of the
     * fit, rounded to four decimal places. Despite the method name, this is
     * R squared, not R.
     *
     * <p>The computation uses only local accumulators, so repeated calls
     * return the same value.
     *
     * @return R^2, or {@code NaN} when it is undefined: fewer than two points,
     *     all x values equal, or all y values equal (SST would be zero)
     */
    public double getR() {
        final int n = points.size();
        if (n < 2 || yMin == yMax) {
            return Double.NaN;
        }
        validateCoefficients();
        if (Float.isNaN(a0) || Float.isNaN(a1)) {
            return Double.NaN;
        }

        final double yBar = sumY / n;
        double sse = 0;  // residual sum of squares
        double sst = 0;  // total sum of squares
        for (DataPoint point : points) {
            final double residual = point.y - (a0 + (double) a1 * point.x);
            final double deviation = point.y - yBar;
            sse += residual * residual;
            sst += deviation * deviation;
        }
        return round(1 - sse / sst, 4);
    }

    /**
     * Rounds {@code v} half-up to {@code scale} decimal places.
     *
     * @param v the value to round
     * @param scale the number of digits to keep after the decimal point
     * @return the rounded value; {@code NaN} and infinities are returned
     *     unchanged because they have no decimal representation
     */
    public double round(double v, int scale) {
        if (Double.isNaN(v) || Double.isInfinite(v)) {
            return v;
        }
        return new BigDecimal(Double.toString(v))
                .setScale(scale, RoundingMode.HALF_UP)
                .doubleValue();
    }
}

/** Interactive demo: reads a few points from standard input and fits a line. */
class LinearRegression {
    /** Capacity of the sample array. */
    private static final int MAX_POINTS = 10;

    /** Number of points the demo asks the user to enter. */
    private static final int INPUT_POINTS = 3;

    /**
     * Reads the samples, fits the regression line, and prints the sums, the
     * line equation and R^2.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        DataPoint[] data = new DataPoint[MAX_POINTS];
        // Slots the user did not fill stay null; the constructor skips them.
        RegressionLine line = new RegressionLine(constructDates(data));
        printSums(line);
        printLine(line);
    }

    /**
     * Prompts for {@code INPUT_POINTS} (x, y) pairs, never more than the array
     * can hold, and stores them at the front of {@code date}.
     *
     * @param date the array to fill
     * @return the same array
     */
    private static DataPoint[] constructDates(DataPoint[] date) {
        // Deliberately not closed: closing the Scanner would close System.in.
        Scanner sc = new Scanner(System.in);
        final int count = Math.min(INPUT_POINTS, date.length);
        for (int i = 0; i < count; i++) {
            System.out.println("请输入第" + (i + 1) + "个x的值:");
            float x = sc.nextFloat();
            System.out.println("请输入第" + (i + 1) + "个y的值:");
            float y = sc.nextFloat();
            date[i] = new DataPoint(x, y);
        }
        return date;
    }

    /** Prints the number of points and the running sums of {@code line}. */
    private static void printSums(RegressionLine line) {
        System.out.println("\n数据点个数 n = " + line.getDataPointCount());
        System.out.println("\nSumX = " + line.getSumX());
        System.out.println("SumY = " + line.getSumY());
        System.out.println("SumXX = " + line.getSumXX());
        System.out.println("SumXY = " + line.getSumXY());
        System.out.println("SumYY = " + line.getSumYY());
    }

    /** Prints the fitted line equation and its R^2. */
    private static void printLine(RegressionLine line) {
        System.out.println("\n回归线公式:y = " + line.getA1()
                + "x + " + line.getA0());
        System.out.println("误差: R^2 = " + line.getR());
    }
}

测试结果:

输入测试数据如下

这里写图片描述
这里写图片描述

程序运行结果为:

这里写图片描述

原创粉丝点击