Spark 线性回归

来源:互联网 发布:linux traceroute 编辑:程序博客网 时间:2024/06/17 09:47

一、建立回归方程

        回归是应用于预测输出变量为连续变化的场景,就像广为流传的房价与面积的关系,如果仅仅是一个因变量和一个自变量,那叫一元线性回归,如果是多个自变量一个因变量就叫多元线性回归。以下图为例:



                                          图片来自http://blog.csdn.net/sunbow0/article/details/45539255

由此可得到线性方程:

代表参数,代表房屋面积和房间个数,我们可以令为零,这样就构成了三元线性回归方程,可以把他们以向量的形式表示如下:

,每个参数代表每一个属性的权重,我们称X为特征或属性,为权重。

        接下来就是求解参数,使得它能尽可能准确的表达我们要预测的值,在训练集中,我们是知道结果y的,所以我们可以通过最小化预测值和真实值之间的差值来获得合适的

二、求解参数

        误差方程为:或者叫做损失函数,m为样本个数。

        1、梯度下降法:

        即求得上式的导数,使其沿着梯度下降的方向走,求得全局最小值。  求导公式为假设有一个样本即m=1:

                                           

        对的更新方式为    冒号代表对每一个的更新,为更新步长,也就是梯度下降的速度,称为学习速率,过大,则有可能越过最小值,过小,则迭代次数过多,函数收敛的慢,j为第几个参数。将带入的更新方程中得到,。当样本数大于1时,的更新方程为:

                                          

即对所有样本的误差先求和,再对没一个参数进行求导,获得梯度,更新每一个参数。这样方法也称为批量梯度下降,这种方法计算量大,收敛的较慢。

         2、随机梯度下降法:

         随机梯度下降法是先利用一个样本对所有参数更新,再计算是否收敛,若不收敛,再读取下一个样本进行更新,如此循环下去,直至收敛。当数据量很大的时候,可能只读取了一部分数据就已经收敛,节省了计算量。

计算方式为:

loop{                 for i=1 to m  遍历样本                 {                         for j=1 to n 遍历参数                         {                                  更新参数                         }                        }         }
        但是,相较于批量梯度下降算法而言,随机梯度下降算法使得J(Θ)趋近于最小值的速度更快,但是有可能造成永远不可能收敛于最小值,有可能一直会在最小值周围震荡,但是实践中,大部分值都能够接近于最小值,效果也都还不错。

spark的线性回归程序分析:

      

package SparkML;import org.apache.log4j.Level;import org.apache.log4j.Logger;import org.apache.spark.ml.linalg.Vectors;import org.apache.spark.ml.regression.LinearRegression;import org.apache.spark.ml.regression.LinearRegressionModel;import org.apache.spark.ml.regression.LinearRegressionTrainingSummary;import org.apache.spark.sql.Dataset;import org.apache.spark.sql.Row;import org.apache.spark.sql.SparkSession;public class mlLinearRegression {public static void main(String[] args){SparkSession spark = SparkSession.builder().master("local")                                         //设置本地运行.appName("ML Linera Regression")         //设置程序名称.getOrCreate();                                     Logger.getLogger("org.apache.spark").setLevel(Level.WARN); Dataset<Row> data = spark.read().format("libsvm")                                       //读入数据的格式,该格式可以自己写程序制作.load("/home/greg/newlibsvm.txt");Dataset<Row>[] sampleData = data.randomSplit(new double[]{0.7,0.3}, 11L); //数据随机分成两份Dataset<Row> train = sampleData[0];                      //训练集Dataset<Row> test = sampleData[1];                        //测试集data.select("features").show();                                  //打印出特征(属性)System.out.println(data.count());                              //共有多少条数据LinearRegression lr = new LinearRegression().setMaxIter(21)                                          //设置最大迭代次数.setRegParam(0.3)                                      //正则化参数.setElasticNetParam(1);//L1,L2混合正则化(aL1+(1-a)L2)LinearRegressionModel lrModel = lr.fit(train);           //开始训练// Print the coefficients and intercept for linear regression.System.out.println("Coefficients: "  + lrModel.coefficients() + " Intercept: " + lrModel.intercept());           //输出参数// Summarize the model over the training set and print out some metrics.LinearRegressionTrainingSummary trainingSummary = lrModel.summary();         System.out.println("numIterations: " + trainingSummary.totalIterations());
                //每次迭代的(loss+regulation)
 System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory()));
//训练集的预测值和实际值的差
                //训练集的误差(label-pred)

trainingSummary.residuals().show(); 
//均方根误差System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError());System.out.println("r2: " + trainingSummary.r2()); //正则化参数Dataset<Row> prediction = lrModel.transform(test);prediction.selectExpr("label","features","prediction").show(); //输出测试集的label和预测值spark.stop();}}



0 0
原创粉丝点击