2017/8/31

来源:互联网 发布:淘宝火影忍者手游cdk 编辑:程序博客网 时间:2024/06/14 11:49
/**
 * Trains a naive Bayes model on the given labeled points.
 *
 * Feature vectors are aggregated per label into a document count and a summed
 * feature vector. These are then turned into smoothed log priors (`pi`) and smoothed
 * log conditional probabilities (`theta`) using the additive smoothing parameter
 * `lambda`.
 *
 * @param data labeled training points. For the Bernoulli model every feature value
 *             must be 0.0 or 1.0; otherwise every feature value must be nonnegative.
 * @return a [[NaiveBayesModel]] whose labels are in ascending order.
 * @throws SparkException if any feature vector violates the value requirement of
 *                        the configured `modelType`.
 * @throws UnknownError if `modelType` is neither `Multinomial` nor `Bernoulli`.
 */
def run(data: RDD[LabeledPoint]): NaiveBayesModel = {
  val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
    val values = v match {
      case sv: SparseVector => sv.values
      case dv: DenseVector => dv.values
    }
    if (!values.forall(_ >= 0.0)) {
      throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
    }
  }

  val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
    val values = v match {
      case sv: SparseVector => sv.values
      case dv: DenseVector => dv.values
    }
    // `x` (not `v`) so the message below clearly refers to the whole vector.
    if (!values.forall(x => x == 0.0 || x == 1.0)) {
      throw new SparkException(
        s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
    }
  }

  // Select the validator once, so that both the first vector seen for a label
  // (createCombiner) and every later one (mergeValue) get the same model-specific check.
  // Validating only with the nonnegativity check in mergeValue would let non-binary
  // values through for the Bernoulli model.
  val requireValues: Vector => Unit = modelType match {
    case Multinomial => requireNonnegativeValues
    case Bernoulli => requireZeroOneBernoulliValues
    case _ =>
      // This should never happen.
      throw new UnknownError(s"Invalid modelType: $modelType.")
  }

  // Aggregates term frequencies per label: (label, (document count, summed features)).
  // TODO: Calling combineByKey and collect creates two stages, we can implement something
  // TODO: similar to reduceByKeyLocally to save one stage.
  val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)](
    createCombiner = (v: Vector) => {
      requireValues(v)
      // Accumulation below adds into this vector in place, so start from a copy
      // rather than the caller's vector.
      (1L, v.copy.toDense)
    },
    mergeValue = (c: (Long, DenseVector), v: Vector) => {
      requireValues(v)
      // c._2 += 1.0 * v
      BLAS.axpy(1.0, v, c._2)
      (c._1 + 1L, c._2)
    },
    mergeCombiners = (c1: (Long, DenseVector), c2: (Long, DenseVector)) => {
      // c1._2 += 1.0 * c2._2
      BLAS.axpy(1.0, c2._2, c1._2)
      (c1._1 + c2._1, c1._2)
    }
  ).collect().sortBy(_._1)

  val numLabels = aggregated.length
  val numDocuments: Long = aggregated.iterator.map { case (_, (n, _)) => n }.sum
  // NOTE(review): `head` throws on empty input (no labels) -- behavior kept as-is.
  val numFeatures = aggregated.head match { case (_, (_, v)) => v.size }

  val labels = new Array[Double](numLabels)
  val pi = new Array[Double](numLabels)
  val theta = Array.fill(numLabels)(new Array[Double](numFeatures))

  // Additive (lambda) smoothing denominator shared by every class prior.
  val piLogDenom = math.log(numDocuments + numLabels * lambda)
  aggregated.zipWithIndex.foreach { case ((label, (n, sumTermFreqs)), i) =>
    labels(i) = label
    pi(i) = math.log(n + lambda) - piLogDenom
    val thetaLogDenom = modelType match {
      case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
      case Bernoulli => math.log(n + 2.0 * lambda)
      case _ =>
        // This should never happen.
        throw new UnknownError(s"Invalid modelType: $modelType.")
    }
    var j = 0
    while (j < numFeatures) {
      theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
      j += 1
    }
  }

  new NaiveBayesModel(labels, pi, theta, modelType)
}
} // closes the enclosing definition whose opening lies before this excerpt

Make a flag.