ExpectationSum
来源:互联网 发布:建设施工安全网络平台 编辑:程序博客网 时间:2024/06/03 05:33
logLikelihood:对数似然(log-likelihood)的累加值
weights:每个类的权重(各样本对该类的后验概率之和,尚未归一化)
means:每个类的均值(按后验概率加权的样本向量之和,尚未归一化)
sigmas:每个类的协方差矩阵(covariance matrix;按后验概率加权的外积 x·xᵀ 之和,尚未归一化)
package org.apache.spark.mllib.clustering

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Vector => BV}
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, BLAS}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils

/**
 * Created by fhqplzj on 16-7-29 at 3:29 PM.
 *
 * Mutable accumulator holding running sums for a set of mixture components.
 * All arrays are indexed by component; every update is performed in place.
 *
 * @param logLikelihood running sum of the per-point log-likelihood terms
 * @param weights       per-component running sum of normalised point contributions
 * @param means         per-component running sum of contribution-weighted points
 * @param sigmas        per-component running sum of contribution-weighted rank-1 updates
 *                      (accumulated through `BLAS.syr`)
 */
class MyExpectationSum(
    var logLikelihood: Double,
    val weights: Array[Double],
    val means: Array[BDV[Double]],
    val sigmas: Array[BDM[Double]]) extends Serializable {

  /**
   * Merges another accumulator into this one, component by component.
   *
   * @param x accumulator to fold in; it must have the same number of components
   * @return this accumulator, after mutation
   * @throws IllegalArgumentException if the component counts differ
   */
  def +=(x: MyExpectationSum): MyExpectationSum = {
    // Checked before any mutation so a mismatch cannot leave this instance half-merged.
    require(
      x.weights.length == weights.length,
      s"Cannot merge sums with ${x.weights.length} components into ${weights.length} components")
    for (i <- weights.indices) {
      weights(i) += x.weights(i)
      means(i) += x.means(i)
      sigmas(i) += x.sigmas(i)
    }
    logLikelihood += x.logLikelihood
    this
  }
}

object MyExpectationSum {

  /**
   * Creates an all-zero accumulator.
   *
   * @param k number of components
   * @param d length of each mean vector; each sigma matrix is d x d
   * @return a fresh accumulator whose sums are all zero
   */
  def zero(k: Int, d: Int): MyExpectationSum = {
    new MyExpectationSum(
      0.0,
      Array.fill(k)(0.0),
      Array.fill(k)(BDV.zeros[Double](d)),
      Array.fill(k)(BDM.zeros[Double](d, d)))
  }

  /**
   * Adds the contribution of a single point to `sums`.
   *
   * The first parameter list fixes the mixture; the second has the
   * `(accumulator, element) => accumulator` shape.
   *
   * @param weights per-component mixing weights
   * @param dists   per-component distributions, aligned with `weights`
   * @param sums    accumulator to update in place
   * @param x       the point to fold in
   * @return `sums`, after mutation
   * @throws IllegalArgumentException if `weights`, `dists` and `sums` disagree on the
   *                                  number of components
   */
  def add(weights: Array[Double], dists: Array[MultivariateGaussian])(
      sums: MyExpectationSum, x: BV[Double]): MyExpectationSum = {
    // zip would silently truncate to the shorter array, so reject mismatches up front,
    // before anything in `sums` has been touched.
    require(
      weights.length == dists.length,
      s"weights (${weights.length}) and dists (${dists.length}) must have the same length")
    require(
      weights.length == sums.weights.length,
      s"Mixture has ${weights.length} components but sums has ${sums.weights.length}")

    // Unnormalised contribution of each component; EPSILON keeps every term, and
    // therefore pSum, strictly positive when the pdf evaluates to zero.
    val p = weights.zip(dists).map {
      case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x)
    }
    val pSum = p.sum
    sums.logLikelihood += math.log(pSum)
    for (i <- p.indices) {
      p(i) /= pSum // normalise so the contributions of one point sum to 1
      sums.weights(i) += p(i)
      // Vector on the left so the operator is resolved from the vector type;
      // element-wise result is the same as p(i) * x.
      sums.means(i) += x * p(i)
      // NOTE(review): the wrapper returned by fromBreeze is discarded, so this update
      // only reaches sums.sigmas(i) if fromBreeze shares the breeze matrix's backing
      // array rather than copying it - confirm for the Spark version in use.
      BLAS.syr(
        p(i),
        Vectors.fromBreeze(x),
        Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
    }
    sums
  }
}
0 0