结合源码分析Spark中的Accuracy(准确率), Precision(精确率), 和F1-Measure

来源:互联网 发布:类似蝰蛇音效的软件 编辑:程序博客网 时间:2024/05/17 08:15

例子

某大学一个系,总共100人,其中男90人,女10人,现在根据每个人的特征,预测性别

Accuracy(准确率)

Accuracy=

计算

由于我知道男生远多于女生,所以我完全无视特征,直接预测所有人都是男生 
我预测所的人都是男生,而实际有90个男生,所以 
预测正确的数量 = 90 
需要预测的总数 = 100 
Accuracy = 90 / 100 = 90%

问题

在男女比例严重不均匀的情况下,我只要预测全是男生,就能获得极高的Accuracy。 
所以在正负样本严重不均匀的情况下,Accuracy指标失效

Precision(精确率), Recall(召回率)

.实际为真实际为假预测为真TPFP预测为假FNTN
# 前面的T和F,代表预测是否正确# 后面的P和N,代表预测是真还是假TP:预测为真,正确了FP:预测为真,结果错了TN:预测为假,正确了FN:预测为假,结果错了
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

Precision=TPTP+FP=

Recall=TPTP+FN=

计算

注意:在正负样本严重不均匀的情况下,正样本必须是数量少的那一类。这里女生是正样本。是不是女生,是则预测为真,不是则预测为假。

  • 如果没有预测为真的情况,计算时分母会为0,所以做了调整,也容易比较Accuracy和Precision, Recall的区别
.实际为真实际为假预测为真10预测为假1089

Accuracy = (1 + 89)/ (1 + 0 + 10 + 89) = 90 / 100 = 0.9 
Precision = 1 / 1 + 0 = 1 
Recall = 1 / 1 + 10 = 0.09090909

注意:为方便与后面Spark的计算结果对比,无限循环小数,我们不做四合五入

问题

虽然我们稍微调整了预测结果,但是Accuracy依然无法反应预测结果。

而Precision在这里达到了1,但是Recall却极低。因此Precision,Recall的组合能够反应我们的预测效果不佳。

但是Precision,Recall在对比的时候会出现问题,比如一个模型的Precision是0.9,Recall是0.19,那么与上面的1和0.0909对比,哪个模型更好呢?

所以我们需要一个指标,能够综合的反应Precision和Recall

F1-Measure

F1值就是Precision和Recall的调和均值

1F1=1Precision+1Recall

整理后:

F1=2×Precision×RecallPrecision+Recall

计算

计算上面提到的对比情况

F1 = (2 * 1 * 0.09090909) / 1 + 0.09090909 = 0.1666666 
F1 = (2 * 0.9 * 0.19) / 0.9 + 0.19 = 0.3137

很显然后一种更好

调整Precision, Recall的权重

Fa=(a2+1)×Precision×Recalla2×(Precision+Recall)

当a等于1时,Precision,Recall各占50%,就是F1-Measure了

Spark源码分析

Spark中API计算Precision,Recall,F1

用Spark API计算出上面我们手工计算出的值

import org.apache.spark.mllib.evaluation.BinaryClassificationMetricsimport org.apache.spark.{SparkConf, SparkContext}object Test {  def main(args: Array[String]) {    val conf = new SparkConf().setAppName("test").setMaster("local") // 调试的时候一定不要用local[*]    val sc = new SparkContext(conf)    sc.setLogLevel("ERROR")    // 我们先构造一个与上文一样的数据    /**      *         实际为真  实际为假      * 预测为真   1        0      * 预测为假   10       89      */    // 左边是预测为真的概率,右边是真实值    val TP = Array((1.0, 1.0)) // 预测为真,实际为真    val TN = new Array[(Double, Double)](89) // 预测为假, 实际为假    for (i <- TN.indices) {      TN(i) = (0.0, 0.0)    }    val FP = new Array[(Double, Double)](10) // 预测为假, 实际为真    for (i <- FP.indices) {      FP(i) = (0.0, 1)    }    val all = TP ++ TN ++ FP    val scoreAndLabels = sc.parallelize(all)    // 打印观察数据    //    scoreAndLabels.collect().foreach(println)    //    println(scoreAndLabels.count())    // 到这里,我们构造了一个与上文例子一样的数据    val metrics = new BinaryClassificationMetrics(scoreAndLabels)    // 下面计算的值,我们先只看右边的数,它表示计算的precision,recall,F1等    // 左边是Threshold,后面会细说    /**      * (1.0,1.0) // precision跟我们自己之前计算的一样      * (0.0,0.11) // 这是什么?先不管      */    metrics.precisionByThreshold().collect().foreach(println)    println("---")    /**      * (1.0,0.09090909090909091) // recall跟我们自己之前计算的一样      * (0.0,1.0) // 先忽略      */    metrics.recallByThreshold().collect().foreach(println)    println("---")    /**      * (1.0,0.16666666666666669) // f1跟我们自己之前计算的一样      * (0.0,0.19819819819819817) // 先忽略      */    metrics.fMeasureByThreshold().collect().foreach(println)  }}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64

至此,我们用Spark API计算出了各个值。但是有几个疑问

  • 无论是precision,recall,还是fMeasure,后面都跟一个ByThreshold,为什么?
  • 这三个指标,不应该是一个数嘛,为什么返回一个RDD,里面包含一堆数?

要弄清楚,就出要知道它们是怎么计算出来的

计算分析(以Precision为例)

  • 从代码的角度,一步步跟踪到Precision的计算公式,公式找到了值也就算出来了
  • 从数据的角度,你的输入数据是怎么一步步到结果的

代码角度

# 类声明# scoreAndLabels是一个RDD,存放预测为真的概率和真实值# numBins,先忽略class BinaryClassificationMetrics (val scoreAndLabels: RDD[(Double, Double)], val numBins: Int)
  • 1
  • 2
  • 3
  • 4

调用BinaryClassificationMetrics的precisionByThreshold方法计算,precision

new BinaryClassificationMetrics(scoreAndLabels).precisionByThreshold()
  • 1

跟踪进入precisionByThreshold方法

def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision)# 调用了createCurve(Precision)# precisionByThreshold返回的RDD,就是这个createCurve方法的返回值# 两个问题# createCurve是什么?# 参数Precision又是什么?
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

跟踪进入createCurve方法

/** Creates a curve of (threshold, metric). */private def createCurve(y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = {    // confusions肯定是一个RDD,因为它调用了map,然后就作为返回值返回了    // 所以confusions是关键,对它做变换,就能得到结果    confusions.map { case (s, c) =>      // precisionByThreshold返回的RDD,左边是threshold,右边是precision      // 所以这里的s,就是threshold      // y(c),就是precision      // y是传入的参数,也就是createCurve(Precision)中的,Precision      // 下面就先看看Precision是什么      (s, y(c))    }}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

跟踪进入Precision

// 上文中的y(c),也就是Precision(c),这语法,自然是调用apply方法/** Precision. Defined as 1.0 when there are no positive examples. */private[evaluation] object Precision extends BinaryClassificationMetricComputer {  override def apply(c: BinaryConfusionMatrix): Double = {    // 看名字numTruePositives,就是TP的数量嘛    // totalPositives = TP + FP    val totalPositives = c.numTruePositives + c.numFalsePositives    // totalPositives为0,也就一个真都没预测    if (totalPositives == 0) {      // 0 / 0,会出错,这里是直接返回1      1.0    } else {      // 公式出现      // Precision = TP / (TP + FP)      c.numTruePositives.toDouble / totalPositives    }  }}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

到这里找到了Precision的计算公式,但是上面提到的两个疑问,还没有解决,Threshold怎么回事,返回RDD干嘛?

但是通过上面的分析,我们找到了线索,confusions这个通过变换就能出结果的变量,也许就是答案。

数据角度

跟踪到confusions的声明

private lazy val (    cumulativeCounts: RDD[(Double, BinaryLabelCounter)],    confusions: RDD[(Double, BinaryConfusionMatrix)]) = {    // ... 省略了60行左右    (cumulativeCounts, confusions)}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

这60行里做了什么,我们拷贝出来,一步步分析

import org.apache.spark.mllib.evaluation.BinaryClassificationMetricsimport org.apache.spark.rdd.RDDimport org.apache.spark.{SparkConf, SparkContext}object Test {  def main(args: Array[String]) {    val conf = new SparkConf().setAppName("test").setMaster("local") // 调试的时候一定不要用local[*]    val sc = new SparkContext(conf)    sc.setLogLevel("ERROR")    val TP = Array((1.0, 1.0))    val TN = new Array[(Double, Double)](89)    for (i <- TN.indices) {      TN(i) = (0.0, 0.0)    }    /**      * *******这里改了********这里改了********这里改了*****      */    // 从10改成了5,有5个样本有60%的概率是真的;另外5个设置成了40%,在下面    val FP1 = new Array[(Double, Double)](5)    for (i <- FP1.indices) {      FP1(i) = (0.6, 1)    }    val FP2 = new Array[(Double, Double)](5) // 有5个样本有40%的概率是真的    for (i <- FP2.indices) {      FP2(i) = (0.4, 1)    }    val all = TP ++ TN ++ FP1 ++ FP2    val scoreAndLabels = sc.parallelize(all, 2) // 调整并行度为2,后面会说,为什么要调整    // 打印观察数据    scoreAndLabels.collect().foreach(println)    val metrics = new BinaryClassificationMetrics(scoreAndLabels)    // 先看下调整后的结果    // 左边一列多了0.6,和0.4,猜的话,应该是因为上面的概率我们添加了0.6和0.4    // 后面会说,具体是怎么出来的    /**      * (1.0,1.0) // 当Threshold为1时,precision是1      * (0.6,1.0) // 当Threshold为0.6时,precision还是1.0      * (0.4,1.0) // 以此类推      * (0.0,0.11)      */    println("-- precisionByThreshold --")    metrics.precisionByThreshold().collect().foreach(println)    /**      * (1.0,0.09090909090909091)      * (0.6,0.5454545454545454)      * (0.4,1.0)      * (0.0,1.0)      */    println("-- recallByThreshold --")    metrics.recallByThreshold().collect().foreach(println)    /**      * (1.0,0.16666666666666669)      * (0.6,0.7058823529411764)      * (0.4,1.0)      * (0.0,0.19819819819819817)      */    println("--  fMeasureByThreshold --")    metrics.fMeasureByThreshold().collect().foreach(println)    // 下面以Precision的计算为例    // 下面的代码是初始化confusions的代码, 在BinaryClassificationMetrics类中,Spark 1.6.1版本的149行开始    // 1. 以预测的概率为key,计算在这个概率下,有多少个;比如:0.6这个概率,出现了多少个(0.6, 1)或0.6, 0)    /**      * (1.0,{numPos: 1, numNeg: 0}) // 1.0,只有一个      * (0.6,{numPos: 5, numNeg: 0}) // 0.6,5个,上面我们修改的      * (0.4,{numPos: 5, numNeg: 0}) // 0.4,同样是5个      * (0.0,{numPos: 0, numNeg: 89}) // 0.0, 89个      */    println("-- binnedCounts --")    val binnedCounts = scoreAndLabels.combineByKey(      // BinaryLabelCounter用于存储累加的numPositives和numNegatives      // 先说下label是什么,scoreAndLabels中右边那一列,只可能是0或1, 是真实值      // BinaryLabelCounter中判断是Positive还是Negatives,是通过label,而不是你自己预测的概率,不是左边那一列      // label > 0.5 为Positive      createCombiner = (label: Double) => new BinaryLabelCounter(0L, 0L) += label,      mergeValue = (c: BinaryLabelCounter, label: Double) => c += label,      mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2    ).sortByKey(ascending = false)    binnedCounts.collect().foreach(println)    println("-- agg --")    // agg是一个数组,collect返回一个数组    // 前面设置了Partition为2,所以这里会有两条数据    // 计算每个Partition中numPos, numNeg的总和    /**      * {numPos: 6, numNeg: 0}      * {numPos: 5, numNeg: 89}      */    val agg = binnedCounts.values.mapPartitions { iter =>      val agg = new BinaryLabelCounter()      iter.foreach(agg += _)      Iterator(agg)    }.collect()    agg.foreach(println)    // partitionwiseCumulativeCounts的长度是Partition数量加1    // partitionwiseCumulativeCounts的每一行是每个Partition的初始numPos, numNeg数量; 这点很重要, 后面会用到    /**      * {numPos: 0, numNeg: 0} // 第一个Partition的初始, 都是0,      * {numPos: 6, numNeg: 0} // 第一个Partition累加后, 等于第二个Partition的初始值;同样可以表明第一个Partition中有6个是Positive      * {numPos: 11, numNeg: 89} // 最后一个位置,就是正负样本的总数; 一共只有两个Partition,都累加起来自然就是总和。      */    println("-- partitionwiseCumulativeCounts --")    val partitionwiseCumulativeCounts =    // 创建一个新的BinaryLabelCounter,然后把agg中的值,从左往右,加一遍      agg.scanLeft(new BinaryLabelCounter())(        (agg: BinaryLabelCounter, c: BinaryLabelCounter) => agg.clone() += c)    partitionwiseCumulativeCounts.foreach(println)    // 打印正负样本总数    val totalCount = partitionwiseCumulativeCounts.last    println(s"Total counts: $totalCount")    // 打印Partition的数量    println("getNumPartitions = " + binnedCounts.getNumPartitions)    // binnedCounts    // binnedCounts经过mapPartitionsWithIndex后就变成了cumulativeCounts    // 先看cumulativeCounts是怎么算出来, 跟下面那组cumulativeCounts数据的结合起来看    /**      * (1.0,{numPos: 1, numNeg: 0}) // 第一行是一样的      * (0.6,{numPos: 5, numNeg: 0}) // 第一行加上第二上,就是cumulativeCounts的第二行      * (0.4,{numPos: 5, numNeg: 0}) // 前三行相加,就是cumulativeCounts的第三行      * (0.0,{numPos: 0, numNeg: 89}) // 以此类推,前四行相加,就是cumulativeCounts的第四行      */    // cumulativeCounts    // 那cumulativeCounts的这些数是什么意思呢?    /**      * (1.0,{numPos: 1, numNeg: 0}) // 当取Threshold为1时,有一个样本,我预测为真      * (0.6,{numPos: 6, numNeg: 0}) // 当取Threshold为0.6时,有6个样本,我预测为真      * (0.4,{numPos: 11, numNeg: 0}) // 以此类推      * (0.0,{numPos: 11, numNeg: 89})      */    println("-- cumulativeCounts --")    // 代码是怎么实现的, 数据可是在RDD上    // 首先binnedCounts是sortByKey排过序的,每个Partitions中是有序的    // 再加上Partition的Index, 和之前的计算的partitionwiseCumulativeCounts, 就能够计算出来    /**      * partitionwiseCumulativeCounts      * {numPos: 0, numNeg: 0} index为0的Partition, 刚开始时, numPos和numNeg都是0      * {numPos: 6, numNeg: 0} 经过index为0的Partition累加后, index为1的Partition, 刚开始时, numPos为6      * {numPos: 11, numNeg: 89}      */    val cumulativeCounts = binnedCounts.mapPartitionsWithIndex(      (index: Int, iter: Iterator[(Double, BinaryLabelCounter)]) => {        val cumCount = partitionwiseCumulativeCounts(index)        iter.map { case (score, c) =>          // index为0时, cumCount为{numPos: 0, numNeg: 0}; 也就是第一个Partition, 刚开始时, numPos和numNeg都是0          // 第一个过来的是, (1.0,{numPos: 1, numNeg: 0}), 经过cumCount += c, 变成了(1.0,{numPos: 1, numNeg: 0})          // 第二个过来的是, (0.6,{numPos: 5, numNeg: 0}), 经过cumCount += c, (0.6,{numPos: 6, numNeg: 0})          // index为1时, cumCount为{numPos: 6, numNeg: 0}; 也就是第二个Partition, 刚开始时, numPos为6          // 第一个过来的是, (0.4,{numPos: 5, numNeg: 0}), 经过cumCount += c, 变成了(0.4,{numPos: 11, numNeg: 0})          // 第二个过来的是, (0.0,{numPos: 0, numNeg: 89}), 经过cumCount += c, 变成了(0.0,{numPos: 11, numNeg: 89})          cumCount += c          (score, cumCount.clone())        }        // preservesPartitioning = true, mapPartitionsWithIndex算子计算过程中,不能修改key      }, preservesPartitioning = true)    cumulativeCounts.collect().foreach(println)    /**      * BinaryConfusionMatrixImpl({numPos: 1, numNeg: 0},{numPos: 11, numNeg: 89})      * 这个矩阵应该转换成下面这种形式来看      *      *          实际为真  实际为假      * 预测为真   1        0      * 预测为假   11-1     89-0      *      * 所以当Threshold不断变化时,矩阵也在不断变化,因此在precision在不断变化      *      * (1.0,BinaryConfusionMatrixImpl({numPos: 1, numNeg: 0},{numPos: 11, numNeg: 89}))      * (0.6,BinaryConfusionMatrixImpl({numPos: 6, numNeg: 0},{numPos: 11, numNeg: 89}))      * (0.4,BinaryConfusionMatrixImpl({numPos: 11, numNeg: 0},{numPos: 11, numNeg: 89}))      * (0.0,BinaryConfusionMatrixImpl({numPos: 11, numNeg: 89},{numPos: 11, numNeg: 89}))      */    println("-- confusions --")    val confusions = cumulativeCounts.map { case (score, cumCount) =>      (score, BinaryConfusionMatrixImpl(cumCount, totalCount).asInstanceOf[BinaryConfusionMatrix])    }    confusions.collect().foreach(println)    println("-- precision --")    def createCurve(y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = {      confusions.map { case (s, c) =>        (s, y(c))      }    }    createCurve(Precision).collect().foreach(println)    sc.stop()  }  object Precision extends BinaryClassificationMetricComputer {    override def apply(c: BinaryConfusionMatrix): Double = {      val totalPositives = c.numTruePositives + c.numFalsePositives      if (totalPositives == 0) {        1.0      } else {        c.numTruePositives.toDouble / totalPositives      }    }  }  trait BinaryClassificationMetricComputer extends Serializable {    def apply(c: BinaryConfusionMatrix): Double  }  class BinaryLabelCounter(var numPositives: Long = 0L, var numNegatives: Long = 0L) extends Serializable {    /** Processes a label. */    def +=(label: Double): BinaryLabelCounter = {      // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle      // -1.0 for negative as well.      if (label > 0.5) numPositives += 1L else numNegatives += 1L      this    }    /** Merges another counter. */    def +=(other: BinaryLabelCounter): BinaryLabelCounter = {      numPositives += other.numPositives      numNegatives += other.numNegatives      this    }    override def clone: BinaryLabelCounter = {      new BinaryLabelCounter(numPositives, numNegatives)    }    override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}"  }  private case class BinaryConfusionMatrixImpl(count: BinaryLabelCounter, totalCount: BinaryLabelCounter) extends BinaryConfusionMatrix {    /** number of true positives */    override def numTruePositives: Long = count.numPositives    /** number of false positives */    override def numFalsePositives: Long = count.numNegatives    /** number of false negatives */    override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives    /** number of true negatives */    override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives    /** number of positives */    override def numPositives: Long = totalCount.numPositives    /** number of negatives */    override def numNegatives: Long = totalCount.numNegatives  }  private trait BinaryConfusionMatrix {    /** number of true positives */    def numTruePositives: Long    /** number of false positives */    def numFalsePositives: Long    /** number of false negatives */    def numFalseNegatives: Long    /** number of true negatives */    def numTrueNegatives: Long    /** number of positives */    def numPositives: Long = numTruePositives + numFalseNegatives    /** number of negatives */    def numNegatives: Long = numFalsePositives + numTrueNegatives  }}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296

到此分析完了Precision的计算过程. 
那么对于Threshold和为什么返回RDD, 我们应该怎么理解呢? 
precisionByThreshold能够让我了解, 随着Threshold的变化, precision是如何变化的

选择Threshold

import com.leo.tianchi.test.Run.BinaryLabelCounterimport org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}import org.apache.spark.mllib.linalg.Vectorimport org.apache.spark.sql.functions.maximport org.apache.spark.sql.{Row, SQLContext}import org.apache.spark.{SparkConf, SparkContext}object Test {  def main(args: Array[String]) {    val conf = new SparkConf().setAppName("test").setMaster("local") // 调试的时候一定不要用local[*]    val sc = new SparkContext(conf)    val sqlContext = new SQLContext(sc)    sc.setLogLevel("ERROR")    import sqlContext.implicits._    // 改成自己的Spark家目录    val training = sqlContext.read.format("libsvm").load("/usr/local/spark/spark-1.6.1-bin-hadoop2.6/data/mllib/sample_libsvm_data.txt")    val lr = new LogisticRegression()      .setMaxIter(100)      .setRegParam(0.3)      .setElasticNetParam(0.8)    val lrModel = lr.fit(training)    val binarySummary = lrModel.summary.asInstanceOf[BinaryLogisticRegressionSummary]    val scoreAndLabels = binarySummary.predictions.select("probability", "label").map {      case Row(score: Vector, label: Double) => (score(1), label)    }    scoreAndLabels.collect().foreach(println)    println("-- binnedCounts --")    /**      * 下面抽取的部分数据做分析      *      * 左边一列是Threshold, 由大到小排列      * 通过观察发现, 刚刚开始时, numPos总是大于0的, 而numNeg总是等于0的; 也就是说当预测为真的概率很高时, 真实值也是真      * 到了中间, 我们预测的概率变化不是很大, 但是真实值却摇摆不定; 这很容易理解, 当我们只有50%的把握时, 比如扔硬币, 就是会一会儿正一会儿反      * 最后, 就都是numNeg大于0, numPos等于零      * (0.7858977614108025,{numPos: 1, numNeg: 0})      * (0.6647454962187126,{numPos: 1, numNeg: 0})      * (0.5408778070820107,{numPos: 1, numNeg: 0})      *  ...省略中数据      * (0.3975829487342493,{numPos: 0, numNeg: 1})      * (0.35639781721605096,{numPos: 1, numNeg: 0})      * (0.33923223159640786,{numPos: 0, numNeg: 1})      * ...省略中数据      * (0.32419460909076375,{numPos: 0, numNeg: 3})      * (0.31989741144442924,{numPos: 0, numNeg: 1})      * (0.3170955715164504,{numPos: 0, numNeg: 1})      */    val binnedCounts = scoreAndLabels.combineByKey(      createCombiner = (label: Double) => new BinaryLabelCounter(0L, 0L) += label,      mergeValue = (c: BinaryLabelCounter, label: Double) => c += label,      mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2    ).sortByKey(ascending = false)    binnedCounts.collect.foreach(println)    binarySummary.precisionByThreshold.show(100000)    binarySummary.recallByThreshold.show(100000)    val fMeasure = binarySummary.fMeasureByThreshold    fMeasure.show(100000)    /**      * 如果要选择Threshold, 这三个指标中, 自然F1最为合适      * 求出最大的F1, 对应的threshold就是最佳的threshold      */    val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)    val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure)      .select("threshold").head().getDouble(0)    println(bestThreshold)  }}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77

参考

准确率(Accuracy), 精确率(Precision), 召回率(Recall)和F1-Measure

阅读全文
0 0
原创粉丝点击