Spark Regularization


Other Spark source code analysis articles:
Spark source code analysis: DecisionTree and GBDT
Spark source code analysis: Random Forest

We will not dwell on the motivation for regularization here; we only state and derive the theory briefly, so that it can be matched against the Spark source code.

1. L2 Regularization

1.1. Algorithm

L2 regularization constrains the squares of the weight coefficients w. The loss function in this case is

C = C_0 + \frac{\lambda}{2n} \sum_w w^2

Here C_0 is the original loss function, the second term is the regularization term, and λ is the regularization coefficient. Since we use gradient descent, the weight update rule becomes
w \to w - \eta \frac{\partial C_0}{\partial w} - \frac{\eta \lambda}{n} w = \left(1 - \frac{\eta \lambda}{n}\right) w - \eta \frac{\partial C_0}{\partial w}
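
To make the weight-decay form of this update concrete, here is a minimal, self-contained sketch of a single L2-regularized gradient step in plain Scala. All numeric values (η, λ/n, the weights, and the gradient) are hypothetical, chosen only for illustration.

// One L2-regularized gradient step: w -> (1 - ηλ/n) w - η ∂C0/∂w.
// All values below are made up for illustration.
object L2StepDemo extends App {
  val eta = 0.1                     // learning rate η
  val lambdaOverN = 0.01            // λ/n, the effective regularization strength
  val w    = Array(0.5, -1.2, 3.0)  // current weights
  val grad = Array(0.2,  0.4, -0.1) // raw gradient ∂C0/∂w

  // first shrink each weight by the factor (1 - ηλ/n), then take the plain gradient step
  val wNew = w.zip(grad).map { case (wi, gi) =>
    (1.0 - eta * lambdaOverN) * wi - eta * gi
  }
  println(wNew.mkString(", "))      // 0.4795, -1.2388, 3.007
}

Note how every weight is first multiplied by the factor (1 - ηλ/n) < 1 regardless of the gradient, which is why L2 regularization is also known as weight decay.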

1.2. Source Code Analysis

L2 regularization is implemented in updater.scala under org.apache.spark.mllib.optimization, concentrated in lines 148-154. In the code, thisIterStepSize is the learning rate, implemented with a simple decay schedule; regParam is the λ from above, and gradient is the raw gradient. The comments explain the simple algebraic rearrangement applied to the update formula.

// Imports from mllib/optimization/Updater.scala, the file containing this class:
import scala.math._
import breeze.linalg.{axpy => brzAxpy, norm => brzNorm, Vector => BV}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.{Vector, Vectors}

@DeveloperApi
class SquaredL2Updater extends Updater {
  override def compute(
      weightsOld: Vector,
      gradient: Vector,
      stepSize: Double,
      iter: Int,
      regParam: Double): (Vector, Double) = {
    // add up both updates from the gradient of the loss (= step) as well as
    // the gradient of the regularizer (= regParam * weightsOld)
    // w' = w - thisIterStepSize * (gradient + regParam * w)
    // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient
    // thisIterStepSize corresponds to η: it shrinks as the iteration count grows,
    // and reducing the step size in later iterations helps convergence
    val thisIterStepSize = stepSize / math.sqrt(iter)
    val brzWeights: BV[Double] = weightsOld.asBreeze.toDenseVector
    // first term of the formula: scale the weights by (1 - η * regParam)
    brzWeights :*= (1.0 - thisIterStepSize * regParam)
    // second term: brzWeights += -thisIterStepSize * gradient
    brzAxpy(-thisIterStepSize, gradient.asBreeze, brzWeights)
    val norm = brzNorm(brzWeights, 2.0)
    (Vectors.fromBreeze(brzWeights), 0.5 * regParam * norm * norm)
  }
}
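
For reference, here is a hypothetical call to this updater showing the interface; the weights, gradient, and hyperparameter values are all made up.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.SquaredL2Updater

object SquaredL2UpdaterDemo extends App {
  val updater = new SquaredL2Updater
  // compute returns the updated weights and the value of the regularization term
  val (newWeights, regVal) = updater.compute(
    weightsOld = Vectors.dense(0.5, -1.2, 3.0), // current weights (made-up values)
    gradient   = Vectors.dense(0.2, 0.4, -0.1), // raw gradient of the loss
    stepSize   = 0.1,                           // base step size
    iter       = 4,                             // thisIterStepSize = 0.1 / sqrt(4) = 0.05
    regParam   = 0.01)                          // λ
  println(s"w' = $newWeights, regVal = $regVal") // regVal = 0.5 * λ * ||w'||²
}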

2. L1 Regularization

2.1. Algorithm

The L1 loss function is

C = C_0 + \frac{\lambda}{n} \sum_w |w|

The gradient update:
w \to w - \frac{\eta \lambda}{n} \, \mathrm{sgn}(w) - \eta \frac{\partial C_0}{\partial w}
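
For contrast with the soft-thresholding implementation discussed in 2.2 below, here is a minimal sketch of what this literal subgradient update would look like. The values are hypothetical, and note that Spark's L1Updater deliberately does not use this form.

// Literal subgradient update w -> w - (ηλ/n) sgn(w) - η ∂C0/∂w.
// All values are made up; Spark's actual implementation differs (see 2.2).
object L1SubgradientDemo extends App {
  val eta = 0.1
  val lambdaOverN = 0.01
  val w    = Array(0.5, -1.2, 3.0)
  val grad = Array(0.2,  0.4, -0.1)
  val wNew = w.zip(grad).map { case (wi, gi) =>
    wi - eta * lambdaOverN * math.signum(wi) - eta * gi
  }
  println(wNew.mkString(", "))  // 0.479, -1.239, 3.009
}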

2.2. Source Code Analysis

The L1Updater class does not follow the formula above strictly. According to its comments, this soft-thresholding approach yields better sparsity in the intermediate solution. Depending on where w falls relative to shrinkage, there are three cases:

  • If w > shrinkage, then clearly w > 0: take w - shrinkage
  • If w < -shrinkage, then clearly w < 0: take w + shrinkage
  • If w is in (-shrinkage, shrinkage): take 0

In fact, the first two cases both reduce the magnitude of w by shrinkage, so they can be merged into sgn(w) * (abs(w) - shrinkage); in the third case abs(w) < shrinkage, so all three collapse into sgn(w) * max(0.0, abs(w) - shrinkage), as the sketch below confirms.
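
A minimal standalone sketch of this merged rule, useful for convincing yourself that the single expression reproduces all three cases (the values are illustrative):

// Soft thresholding: sgn(w) * max(0, |w| - shrinkage) covers all three cases.
object SoftThresholdDemo extends App {
  def softThreshold(w: Double, shrinkage: Double): Double =
    math.signum(w) * math.max(0.0, math.abs(w) - shrinkage)

  val shrinkage = 0.5
  println(softThreshold( 1.3, shrinkage)) //  0.8  (w >  shrinkage: w - shrinkage)
  println(softThreshold(-1.3, shrinkage)) // -0.8  (w < -shrinkage: w + shrinkage)
  println(softThreshold( 0.2, shrinkage)) //  0.0  (|w| < shrinkage: zeroed out)
}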

// This class lives in the same Updater.scala file, with the same imports as above.

/**
 * Instead of subgradient of the regularizer, the proximal operator for the
 * L1 regularization is applied after the gradient step. This is known to
 * result in better sparsity of the intermediate solution.
 *
 * The corresponding proximal operator for the L1 norm is the soft-thresholding
 * function. That is, each weight component is shrunk towards 0 by shrinkageVal.
 *
 * If w >  shrinkageVal, set weight component to w-shrinkageVal.
 * If w < -shrinkageVal, set weight component to w+shrinkageVal.
 * If -shrinkageVal < w < shrinkageVal, set weight component to 0.
 *
 * Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal)
 */
@DeveloperApi
class L1Updater extends Updater {
  override def compute(
      weightsOld: Vector,
      gradient: Vector,
      stepSize: Double,
      iter: Int,
      regParam: Double): (Vector, Double) = {
    val thisIterStepSize = stepSize / math.sqrt(iter)
    // Take gradient step (the second term of the update formula)
    val brzWeights: BV[Double] = weightsOld.asBreeze.toDenseVector
    brzAxpy(-thisIterStepSize, gradient.asBreeze, brzWeights)
    // Apply proximal operator (soft thresholding)
    val shrinkageVal = regParam * thisIterStepSize
    var i = 0
    val len = brzWeights.length
    while (i < len) {
      val wi = brzWeights(i)
      brzWeights(i) = signum(wi) * max(0.0, abs(wi) - shrinkageVal)
      i += 1
    }
    (Vectors.fromBreeze(brzWeights), brzNorm(brzWeights, 1.0) * regParam)
  }
}
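
As with L2, here is a hypothetical call with made-up values; regParam is deliberately chosen large so that the soft threshold zeroes out the small first weight, illustrating the sparsity effect described in the comments.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.L1Updater

object L1UpdaterDemo extends App {
  val updater = new L1Updater
  val (newWeights, regVal) = updater.compute(
    weightsOld = Vectors.dense(0.02, -1.2, 3.0),
    gradient   = Vectors.dense(0.2, 0.4, -0.1),
    stepSize   = 0.1,
    iter       = 4,    // thisIterStepSize = 0.05, so shrinkageVal = 0.5 * 0.05 = 0.025
    regParam   = 0.5)
  // first weight: 0.02 - 0.05 * 0.2 = 0.01 after the gradient step;
  // |0.01| < shrinkageVal = 0.025, so it is set to exactly 0.0
  println(s"w' = $newWeights, regVal = $regVal") // regVal = λ * ||w'||₁
}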

The parameters have the same meanings as in the L2 case, and the logic follows the derivation above; there is not much more to add.
