Apache Math Linear Regression

来源:互联网 发布:淘宝售后可以修改几次 编辑:程序博客网 时间:2024/06/12 21:24
package com;import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;import org.apache.commons.math3.stat.regression.SimpleRegression;public class MathLinearRegression {public static void main(String[] args) {simpleRegression();multipleRegression();}private static void multipleRegression() {System.out.println("multipleRegression");final OLSMultipleLinearRegression regression2 = new OLSMultipleLinearRegression();double[] y = { 2, 3, 4, 5, 6 };double[][] x2 = { { 1 }, { 2 }, { 3 }, { 4 }, { 5 }, };regression2.newSampleData(y, x2);double[] beta = regression2.estimateRegressionParameters();for (double d : beta) {System.out.println("D: " + d);}System.out.println("prediction for 1.5 = " + predict(new double[] { 1.5, 1 }, beta));}private static double predict(double[] data, double[] beta) {double result = 0;for (int i = 0; i < data.length; i++) {result += data[i] * beta[i];}return result;}private static void simpleRegression() {System.out.println("simpleRegression");// creating regression object, passing true to have intercept termSimpleRegression simpleRegression = new SimpleRegression(true);// passing data to the model// model will be fitted automatically by the classsimpleRegression.addData(new double[][] { { 1, 2 }, { 2, 3 }, { 3, 4 },{ 4, 5 }, { 5, 6 } });// querying for model parametersSystem.out.println("slope = " + simpleRegression.getSlope());System.out.println("intercept = " + simpleRegression.getIntercept());// trying to run model for unknown dataSystem.out.println("prediction for 1.5 = "+ simpleRegression.predict(1.5));}}

simpleRegression
slope = 1.0
intercept = 1.0
prediction for 1.5 = 2.5
multipleRegression
D: 1.0
D: 1.0
prediction for 1.5 = 2.5


0 0
原创粉丝点击