spark1.2.0源码MLlib --- 决策树-03

来源:互联网 发布:数据漫游关闭还是开启 编辑:程序博客网 时间:2024/06/06 11:04

本章重点关注树中各节点分裂过程中,如何将相应的数据进行汇总,以便之后计算节点不纯度及信息增益,最终确定分裂的顺序。


首先,从 DecisionTree.findBestSplits() 开始,这个方法代码很长,按照执行顺序来看,代码如下:

    val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {  //节点缓存的情况      input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>        // Construct a nodeStatsAggregators array to hold node aggregate stats,        // each node will have a nodeStatsAggregator        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>            Some(nodeToFeatures(nodeIndex))          }          new DTStatsAggregator(metadata, featuresForNode)        }        // iterator all instances in current partition and update aggregate stats        points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,        // which can be combined with other partition using `reduceByKey`        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator      }    } else {  //节点不缓存的情况      input.mapPartitions { points =>        // Construct a nodeStatsAggregators array to hold node aggregate stats,        // each node will have a nodeStatsAggregator        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>   //每个节点对应一个聚合器,存储该节点下的统计信息          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>            Some(nodeToFeatures(nodeIndex))          }          new DTStatsAggregator(metadata, featuresForNode)  //创建一个节点聚合器        }        // iterator all instances in current partition and update aggregate stats        points.foreach(binSeqOp(nodeStatsAggregators, _))  //统计该节点下的信息(node,features,bins),放入聚合器中        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,        // which can be combined with other partition using `reduceByKey`        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator  //转换为kv对,将和其他分区的聚合器进行合并      }    }
当第一次执行时,只会有一个根节点(节点ID为1),之后,按照树的层级依次递增下去,直到叶子节点为止(或者达到最大的树深度为止)。

继续看关键的一步代码,points.foreach(binSeqOp(nodeStatsAggregators, _)),按分区来聚合,具体代码如下:

   def binSeqOp(        agg: Array[DTStatsAggregator],        baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>        val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,          bins, metadata.unorderedFeatures)  //根据传递过来的样本值,判断属于哪个节点下        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) //根据上一步的节点索引,将样本统计信息放入相应的节点聚合器中      }      agg    }

继续跟踪 nodeBinSeqOp() 方法:

    def nodeBinSeqOp(        treeIndex: Int,        nodeInfo: RandomForest.NodeIndexInfo,        agg: Array[DTStatsAggregator],        baggedPoint: BaggedPoint[TreePoint]): Unit = {      if (nodeInfo != null) {        val aggNodeIndex = nodeInfo.nodeIndexInGroup        val featuresForNode = nodeInfo.featureSubset        val instanceWeight = baggedPoint.subsampleWeights(treeIndex)        if (metadata.unorderedFeatures.isEmpty) {          orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)  //特征属性值为有序的情况        } else {          mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,  //有无序的特征属性值            instanceWeight, featuresForNode)        }      }    }

只看有序的那部分:

  private def orderedBinSeqOp(      agg: DTStatsAggregator,      treePoint: TreePoint,      instanceWeight: Double,      featuresForNode: Option[Array[Int]]): Unit = {    val label = treePoint.label    // Iterate over features.    if (featuresForNode.nonEmpty) {  //节点只使用一部分特征      // Use subsampled features      var featureIndexIdx = 0      while (featureIndexIdx < featuresForNode.get.size) {        val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))        agg.update(featureIndexIdx, binIndex, label, instanceWeight)  //更新聚合器的统计信息        featureIndexIdx += 1      }    } else {  //使用所有特征      // Use all features      val numFeatures = agg.metadata.numFeatures      var featureIndex = 0      while (featureIndex < numFeatures) {        val binIndex = treePoint.binnedFeatures(featureIndex)        agg.update(featureIndex, binIndex, label, instanceWeight)  //更新聚合器的统计信息        featureIndex += 1      }    }  }

继续查看 agg.update() 的代码:

  def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {    val i = featureOffsets(featureIndex) + binIndex * statsSize  //该特征索引对应的偏移量位置,在一个统计数组(allStats)中    impurityAggregator.update(allStats, i, label, instanceWeight)  }
其中impurityAggregator有三种实现类:

  val impurityAggregator: ImpurityAggregator = metadata.impurity match {    case Gini => new GiniAggregator(metadata.numClasses)  //gini系数聚合器,分类使用    case Entropy => new EntropyAggregator(metadata.numClasses)  //熵聚合器,分类使用    case Variance => new VarianceAggregator()  //方差聚合器,线性回归使用    case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")  }

看其中一种 EntropyAggregator ,其update()方法如下:

  def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {    if (label >= statsSize) {      throw new IllegalArgumentException(s"EntropyAggregator given label $label" +        s" but requires label < numClasses (= $statsSize).")    }    allStats(offset + label.toInt) += instanceWeight   //最后汇总到该数组中  }

 ***********  The End  *********** 


0 0
原创粉丝点击