【spark source】Spark LinearRegression源码解读

来源:互联网 发布:数据库基础pdf 编辑:程序博客网 时间:2024/05/16 05:03

:org.apache.spark.mllib.regression.RegressionModel

定义线性回归模型的predict接口

:org.apache.spark.mllib.regression.impl.GLMRegressionModel

从文件中加载Model,或保存Model到文件中

:org.apache.spark.mllib.pmml.PMMLExportable

把模型转换为Predictive Model Markup Language (PMML)

:org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.GeneralizedLinearModel

定义weights和intercept,同时实现批量predict方法

:org.apache.spark.mllib.regression.LinearRegression.LinearRegressionModel   extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable  with Saveable with PMMLExportable 

线性回归Model,实现具体predict方法:weights*data+intercept

:org.apache.spark.mllib.optimization.Optimizer

定义凸优化算法接口optimize

:org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm

定义Template优化方法run(),具体优化策略由子类的Optimizer实现

-判断RDD是否缓存

-根据验证方法验证数据

-对数据进行FeatureScaling(withStd = true, withMean = false)

-训练模型

-rescale模型的weights和intercept

-去掉数据缓存

-返回模型

:org.apache.spark.mllib.optimization.Gradient

定义梯度和loss计算接口,子类实现具体的梯度计算

:org.apache.spark.mllib.optimization.Gradient.LeastSquaresGradient

计算最小二乘法的梯度:L = 1/2n ||A weights-y||^2

:org.apache.spark.mllib.optimization.Updater

更新权值weights和Regularization value接口

:org.apache.spark.mllib.optimization.Updater.SimpleUpdater

不带regularization更新权值

:org.apache.spark.mllib.optimization.GradientDescent

mini-batch梯度下降法实现

-每次采样一批次的data

-汇总计算梯度

-更新权值

:org.apache.spark.mllib.regression.LinearRegression.LinearRegressionWithSGD extends GeneralizedLinearAlgorithm[LinearRegressionModel]

不带regularization的SGD实现,优化function计算mean squared error: f(weights) = 1/n ||A weights-y||^2^



0 0