Spark MLlib Source Code Analysis: Random Forest (Part 4)


Spark Source Code Analysis: Random Forest (Part 1)
Spark Source Code Analysis: Random Forest (Part 2)
Spark Source Code Analysis: Random Forest (Part 3)
Spark Source Code Analysis: Random Forest (Part 5)

6.4. Node Splitting

The main logic lives in the DecisionTree.findBestSplits function, which is the core of RF training.

DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
  treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)

6.4.1. Statistics Collection

The statistics are gathered in two stages: first each partition is aggregated separately, then the per-partition results are accumulated into global statistics.

6.4.1.1. Fetching Each Node's Feature Subset
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)

This fetches each node's feature subset. If no feature subsampling is needed it is None; otherwise it returns a Map[Int, Array[Int]], which is simply the NodeIndexInfo entries of the earlier treeToNodeToIndexInfo converted into a map. The result is broadcast as nodeToFeaturesBc.
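A simplified sketch of the map it builds (not the verbatim Spark helper; the real getNodeToFeatures is a private function with a slightly different signature, and `subsamplingFeatures` stands in for the corresponding metadata flag):

// Sketch: groupIndex -> sampled feature indices, or None when no feature
// subsampling is in effect.
def getNodeToFeatures(
    treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
    subsamplingFeatures: Boolean): Option[Map[Int, Array[Int]]] = {
  if (!subsamplingFeatures) {
    None
  } else {
    val nodeToFeatures = scala.collection.mutable.HashMap.empty[Int, Array[Int]]
    for {
      nodeIdToInfo <- treeToNodeToIndexInfo.values
      info <- nodeIdToInfo.values
    } nodeToFeatures(info.nodeIndexInGroup) = info.featureSubset.get
    Some(nodeToFeatures.toMap)
  }
}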

6.4.1.2. Per-Partition Aggregation

This is a chain of function calls; we analyze it layer by layer.

val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
  input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
    // Construct a nodeStatsAggregators array to hold node aggregate stats,
    // each node will have a nodeStatsAggregator
    val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
      val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
        Some(nodeToFeatures(nodeIndex))
      }
      new DTStatsAggregator(metadata, featuresForNode)
    }

    // iterator all instances in current partition and update aggregate stats
    points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))

    // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
    // which can be combined with other partition using `reduceByKey`
    nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
  }
} else {
  input.mapPartitions { points =>
    // Construct a nodeStatsAggregators array to hold node aggregate stats,
    // each node will have a nodeStatsAggregator
    val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
      val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
        Some(nodeToFeatures(nodeIndex))
      }
      new DTStatsAggregator(metadata, featuresForNode)
    }

    // iterator all instances in current partition and update aggregate stats
    points.foreach(binSeqOp(nodeStatsAggregators, _))

    // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
    // which can be combined with other partition using `reduceByKey`
    nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
  }
}

For each partition, an array of DTStatsAggregator is first constructed, one per node in the group. Note that an actual array is used here, so how does a node find its own aggregator? As mentioned earlier, the first member of NodeIndexInfo is groupIndex, whose value is the node's ordinal within the group, and it corresponds exactly to the index into this aggregator array: take groupIndex from NodeIndexInfo and use it as the array index to get that node's agg. Each DTStatsAggregator is constructed from the metadata and the node's feature subset. Every data point is then folded into the DTStatsAggregator via the local function binSeqOp:

/**
 * Performs a sequential aggregation over a partition.
 *
 * Each data point contributes to one node. For each feature,
 * the aggregate sufficient statistics are updated for the relevant bins.
 *
 * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
 *             each (node, feature, bin).
 * @param baggedPoint   Data point being aggregated.
 * @return  agg
 */
def binSeqOp(
    agg: Array[DTStatsAggregator],
    baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
  // for each node
  treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
    val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
      bins, metadata.unorderedFeatures)
    nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
  }
  agg
}

It first calls predictNodeIndex to compute nodeIndex. In the first round, or for a leaf node, it simply returns node.id. In later rounds, since the root node of each tree is passed in, it walks down from the root to determine which node the point belongs to; because nodes have already been split, this effectively partitions the samples. For example, if the current node is the root's left child but the point is predicted to belong to the right child, the subsequent nodeBinSeqOp call returns immediately and the point is not counted, trading a modest amount of time for the space an explicit partition of the sample set would require, which is rather elegant.

/**
 * Get the node index corresponding to this data point.
 * This function mimics prediction, passing an example from the root node down to a leaf
 * or unsplit node; that node's index is returned.
 *
 * @param node  Node in tree from which to classify the given data point.
 * @param binnedFeatures  Binned feature vector for data point.
 * @param bins possible bins for all features, indexed (numFeatures)(numBins)
 * @param unorderedFeatures  Set of indices of unordered features.
 * @return  Leaf index if the data point reaches a leaf.
 *          Otherwise, last node reachable in tree matching this example.
 *          Note: This is the global node index, i.e., the index used in the tree.
 *                This index is different from the index used during training a particular
 *                group of nodes on one call to [[findBestSplits()]].
 */
private def predictNodeIndex(
    node: Node,
    binnedFeatures: Array[Int],
    bins: Array[Array[Bin]],
    unorderedFeatures: Set[Int]): Int = {
  if (node.isLeaf || node.split.isEmpty) {
    // Node is either leaf, or has not yet been split.
    node.id
  } else {
    // decide whether the point goes to the current node's left or right child
    val featureIndex = node.split.get.feature
    val splitLeft = node.split.get.featureType match {
      case Continuous => {
        val binIndex = binnedFeatures(featureIndex)
        val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
        // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
        // We do not need to check lowSplit since bins are separated by splits.
        featureValueUpperBound <= node.split.get.threshold
      }
      case Categorical => {
        val featureValue = binnedFeatures(featureIndex)
        node.split.get.categories.contains(featureValue)
      }
      case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.")
    }
    if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
      // children not constructed yet:
      // Return index from next layer of nodes to train
      if (splitLeft) {
        Node.leftChildIndex(node.id)
      } else {
        Node.rightChildIndex(node.id)
      }
    } else {
      // complete left/right children exist below: recurse
      if (splitLeft) {
        predictNodeIndex(node.leftNode.get, binnedFeatures, bins, unorderedFeatures)
      } else {
        predictNodeIndex(node.rightNode.get, binnedFeatures, bins, unorderedFeatures)
      }
    }
  }
}

It then calls the nodeBinSeqOp function:

/**
 * Performs a sequential aggregation over a partition for a particular tree and node.
 *
 * For each feature, the aggregate sufficient statistics are updated for the relevant
 * bins.
 *
 * @param treeIndex Index of the tree that we want to perform aggregation for.
 * @param nodeInfo The node info for the tree node.
 * @param agg Array storing aggregate calculation, with a set of sufficient statistics
 *            for each (node, feature, bin).
 * @param baggedPoint Data point being aggregated.
 */
def nodeBinSeqOp(
    treeIndex: Int,
    nodeInfo: RandomForest.NodeIndexInfo,
    agg: Array[DTStatsAggregator],
    baggedPoint: BaggedPoint[TreePoint]): Unit = {
  if (nodeInfo != null) {
    // the node's groupIndex, see above
    val aggNodeIndex = nodeInfo.nodeIndexInGroup
    // the feature subset used by this node
    val featuresForNode = nodeInfo.featureSubset
    // the number of times this sample occurs in this tree: 0/1/k
    val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
    if (metadata.unorderedFeatures.isEmpty) {
      orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
    } else {
      mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
        metadata.unorderedFeatures, instanceWeight, featuresForNode)
    }
  }
}

The function's parameters are the treeIndex, the node's NodeIndexInfo, the array of aggregators for all nodes, and the sample. It operates on a single node, and here you can see that the node's aggregator is fetched by using nodeIndexInGroup, the first member of NodeIndexInfo, as the index into the agg array.
If there are no unordered features, orderedBinSeqOp is called:

/**
 * Helper for binSeqOp, for regression and for classification with only ordered features.
 *
 * For each feature, the sufficient statistics of one bin are updated.
 *
 * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
 *             each (feature, bin).
 * @param treePoint  Data point being aggregated.
 * @param instanceWeight  Weight (importance) of instance in dataset.
 */
private def orderedBinSeqOp(
    agg: DTStatsAggregator, // the node's aggregator
    treePoint: TreePoint,
    instanceWeight: Double,
    featuresForNode: Option[Array[Int]]): Unit = {
  val label = treePoint.label
  // Iterate over features.
  if (featuresForNode.nonEmpty) {
    // Use subsampled features
    var featureIndexIdx = 0
    while (featureIndexIdx < featuresForNode.get.size) {
      // continuous feature: the discretized bin index
      // categorical feature: featureValue.toInt
      val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
      agg.update(featureIndexIdx, binIndex, label, instanceWeight)
      featureIndexIdx += 1
    }
  } else {
    // Use all features
    val numFeatures = agg.metadata.numFeatures
    var featureIndex = 0
    while (featureIndex < numFeatures) {
      val binIndex = treePoint.binnedFeatures(featureIndex)
      agg.update(featureIndex, binIndex, label, instanceWeight)
      featureIndex += 1
    }
  }
}

The function distinguishes whether all features are in use; the only difference is that under feature subsampling, a feature's actual index must first be looked up in featuresForNode.
In essence, the function takes each feature the sample uses, its bin value, the label, and the weight, and folds them into the aggregator; the update logic was covered earlier in this series, and is sketched again below.
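As a rough illustration, here is a hypothetical sketch of what a single classification update amounts to (not the actual DTStatsAggregator code, which delegates to an ImpurityAggregator, but the arithmetic is essentially this offset computation):

// allStats packs per-class weighted counts for every (feature, bin);
// statsSize is the number of classes for Gini/entropy.
def update(
    allStats: Array[Double],
    featureOffset: Int,   // start of this feature's region in allStats
    statsSize: Int,
    binIndex: Int,
    label: Double,
    instanceWeight: Double): Unit = {
  allStats(featureOffset + binIndex * statsSize + label.toInt) += instanceWeight
}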
When unordered categorical features are present, mixedBinSeqOp is used instead; only the handling of unordered categorical features differs from orderedBinSeqOp:

// Unordered feature
val featureValue = treePoint.binnedFeatures(featureIndex)
// Locate this feature value's range in allStats.
// The feature's start position comes from featureOffsets; its length is the number of
// bins times the stats size, i.e. 2 * (2^(M-1) - 1) * statsSize.
// Each split divides the sample set into two parts; in allStats all the left parts are
// stored contiguously and all the right parts are stored contiguously, rather than
// interleaving left and right. Hence the left start position is read directly from
// featureOffsets, and the right start position is offset by (2^(M-1) - 1) * statsSize.
val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
  agg.getLeftRightFeatureOffsets(featureIndexIdx)
// Update the left or right bin for each split.
val numSplits = agg.metadata.numSplits(featureIndex)
var splitIndex = 0
while (splitIndex < numSplits) {
  // split.categories holds the feature-value combination of the left half;
  // splitIndex acts as the discretized feature index
  if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
    agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
  } else {
    agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
  }
  splitIndex += 1
}
6.4.1.3. Global Aggregation
partitionAggregates.reduceByKey((a, b) => a.merge(b))

This simply adds the per-partition statistics stored in allStats together, element by element, to obtain the global statistics, as the sketch below illustrates.
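Conceptually the merge is an element-wise addition of the two flat allStats arrays; a hedged sketch (the actual DTStatsAggregator.merge also validates that both aggregators share the same layout):

// Add b's statistics into a, slot by slot, and return a so that reduceByKey
// can keep folding partitions into the same buffer.
def merge(a: Array[Double], b: Array[Double]): Array[Double] = {
  require(a.length == b.length, "aggregators must share the same layout")
  var i = 0
  while (i < a.length) {
    a(i) += b(i)
    i += 1
  }
  a
}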

6.4.2. bestSplits

With all the statistics in hand, we can iterate over all the features, compute the impurity gain of each candidate, and determine the best split.

val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
  .map { case (nodeIndex, aggStats) =>
    val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
      nodeToFeatures(nodeIndex)
    }

    // find best split for each node
    val (split: Split, stats: InformationGainStats, predict: Predict) =
      binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
    (nodeIndex, (split, stats, predict))
  }.collectAsMap()

For each node this calls binsToBestSplit, which we walk through in detail below.

6.4.2.1. init

The function first determines which level of the tree the node is on. The tree structure is indexed as follows:

[Figure: binary-tree node ids; the root has id 1, and node i has children 2i and 2i+1]

With ids assigned this way, finding a node's level only requires locating the highest set bit in the binary representation of its id. For example, 6 in binary is 110; its highest 1 bit is in the 3rd position, so it is on the 3rd layer (level 2 in the code below, where the root counts as level 0).
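A minimal sketch of that bit trick (Spark's Node.indexToLevel follows the same idea; treat the exact form here as an assumption):

// floor(log2(nodeId)): position of the highest set bit, with the root (id 1)
// at level 0.
def indexToLevel(nodeId: Int): Int =
  31 - Integer.numberOfLeadingZeros(nodeId)

// indexToLevel(1) == 0, indexToLevel(6) == 2  (6 = 110b, the 3rd layer)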
It then fetches the current node's prediction and impurity:

// calculate predict and impurity if current node is top node
val level = Node.indexToLevel(node.id)
var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
  None
} else {
  Some((node.predict, node.impurity))
}
6.4.2.2. Continuous Features

For a continuous feature, once one of its values is chosen as the best split, the node's samples are divided into the part greater than that value and the part less than or equal to it, and the class distribution of each part must be computed. On the other hand, since we are searching for the best split, every candidate value must be evaluated. A neat trick is to accumulate the bin statistics from left to right: when some value serves as the split, the running total up to it is exactly the left (less-than-or-equal) side, and subtracting it from the rightmost cumulative value gives the right side.
[Figure: per-bin class counts for a feature with six values v1..v6; the first row holds the left-to-right cumulative sums, the second row the raw per-bin distribution]
For example, in the situation in the figure, a feature has six values; the first row holds the left-cumulative counts and the second row the raw distribution. When v2 is used as the split, the left side's distribution is c0: 8, c1: 5, and the right side is v6's cumulative counts minus v2's: c0: 19 - 8 = 11, c1: 14 - 5 = 9. The sketch below replays this computation on hypothetical per-bin counts chosen to match those numbers.
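// Hypothetical per-bin (class0, class1) counts for v1..v6; only the totals
// (19, 14) and the prefix at v2 (8, 5) are taken from the example above.
val counts = Array((3, 2), (5, 3), (2, 4), (4, 1), (3, 2), (2, 2))

// Left-to-right cumulative sums: prefix(i) = stats of bins 0..i.
val prefix = counts.scanLeft((0, 0)) { case ((a0, a1), (b0, b1)) =>
  (a0 + b0, a1 + b1)
}.tail
val total = prefix.last // (19, 14)

// Splitting at v2: the left child is just prefix(1), the right child is
// total minus prefix(1) -- no second pass over the data is needed.
val left = prefix(1)                                 // (8, 5)
val right = (total._1 - left._1, total._2 - left._2) // (11, 9)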

if (binAggregates.metadata.isContinuous(featureIndex)) {
  // Cumulative sum (scanLeft) of bin statistics.
  // Afterwards, binAggregates for a bin is the sum of aggregates for
  // that bin + all preceding bins.
  // accumulate, as described above
  val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
  var splitIndex = 0
  while (splitIndex < numSplits) {
    binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
    splitIndex += 1
  }
  // Find best split.
  val (bestFeatureSplitIndex, bestFeatureGainStats) =
    Range(0, numSplits).map { case splitIdx =>
      val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
      val rightChildStats =
        binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
      rightChildStats.subtract(leftChildStats)
      // obtain the node's impurity; when level == 0 it must be computed from
      // the current class distribution
      predictWithImpurity = Some(predictWithImpurity.getOrElse(
        calculatePredictImpurity(leftChildStats, rightChildStats)))
      val gainStats = calculateGainForSplit(leftChildStats,
        rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
      (splitIdx, gainStats)
    }.maxBy(_._2.gain)
  (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)

When computing the impurity gain of splitting the node with a candidate split, calculateGainForSplit is called. It computes the left and right impurities, combines them weighted by sample proportion, and also computes the left and right predictions. The code is straightforward and is not reproduced here; a sketch of the core formula follows.
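A minimal sketch of the gain formula, assuming Gini/entropy-style impurities (the real calculateGainForSplit additionally returns invalid stats when the minInstancesPerNode or minInfoGain constraints are violated):

def informationGain(
    parentImpurity: Double,
    leftCount: Double, leftImpurity: Double,
    rightCount: Double, rightImpurity: Double): Double = {
  val totalCount = leftCount + rightCount
  val leftWeight = leftCount / totalCount
  val rightWeight = rightCount / totalCount
  // gain = impurity(parent) - weighted sum of the children's impurities
  parentImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
}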

6.4.2.3. Unordered categorical feature

Only the way the left and right class statistics are obtained differs; everything else is the same as for continuous features.

6.4.2.4. Ordered categorical feature

For a continuous feature, the feature values (or rather their bin indices) are ordered, i.e. the values can be ranked numerically, so taking one value as the split simply separates a left part from a right part. For an unordered categorical feature, which side a value falls on under a given split is fixed by enumeration. For an ordered categorical feature, the values carry an ordering but no absolute magnitude; they could be handled approximately the way continuous features are, but Spark does some extra work here that is likely a bit better.
Spark first computes a centroid for each category and sorts the categories by it; the resulting rank plays the role of a continuous feature's binIndex. For example, take the class-1 count per category as the centroid and suppose categories 0, 1, 2, 3 have class-1 counts 4, 2, 1, 3. Handled like a continuous feature and split at value 1, the impurity gain would be computed over the groups {0, 1} and {2, 3}. Under the centroid ordering the sorted sequence is 2, 1, 3, 0, so splitting at 1 produces the groups {2, 1} and {3, 0}, as the sketch below verifies.
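A one-line check of that reordering (hypothetical values straight from the example above):

// categories 0,1,2,3 with class-1 counts 4,2,1,3; in binary classification
// the centroid is the class-1 count, so sorting by it gives the bin order.
val centroids = List(0 -> 4.0, 1 -> 2.0, 2 -> 1.0, 3 -> 3.0)
val order = centroids.sortBy(_._2).map(_._1) // List(2, 1, 3, 0)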

// Ordered categorical feature
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val numBins = binAggregates.metadata.numBins(featureIndex)

/* Each bin is one category (feature value).
 * The bins are ordered based on centroidForCategories, and this ordering determines which
 * splits are considered.  (With K categories, we consider K - 1 possible splits.)
 *
 * centroidForCategories is a list: (category, centroid)
 */
val centroidForCategories = Range(0, numBins).map { case featureValue =>
  val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
  val centroid = if (categoryStats.count != 0) {
    if (binAggregates.metadata.isMulticlass) {
      // For categorical variables in multiclass classification,
      // the bins are ordered by the impurity of their corresponding labels.
      categoryStats.calculate()
    } else if (binAggregates.metadata.isClassification) {
      // For categorical variables in binary classification,
      // the bins are ordered by the count of class 1.
      categoryStats.stats(1)
    } else {
      // For categorical variables in regression,
      // the bins are ordered by the prediction.
      categoryStats.predict
    }
  } else {
    Double.MaxValue
  }
  (featureValue, centroid)
}

logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))

// bins sorted by centroids
val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)

The code above chooses the centroid differently per task: for multiclass classification it uses the impurity; for binary classification, the count of class 1; for regression, the prediction (in practice the mean). The categories are then re-sorted by centroid.
The remaining processing is essentially the same as for continuous features: accumulate in the sorted order, then compute the left and right impurities and the impurity gain. Because a split has to be returned, and the precomputed splits for ordered categorical features were empty arrays, a Split is constructed here whose fourth argument carries the actual feature values, by analogy with the unordered case; roughly as sketched below.
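Roughly how the winning split is materialized (a hedged reconstruction; binsToBestSplit in DecisionTree.scala is the authoritative version):

// The left side of the winning split consists of the first
// bestFeatureSplitIndex + 1 categories in centroid order; the threshold slot
// is unused for categorical splits.
val categoriesForSplit =
  categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
val bestFeatureSplit =
  new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)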

Once the gain has been computed for every feature, the split with the maximal gain is selected; finally collectAsMap produces a map keyed by nodeIndex whose values are (Split, InformationGainStats, Predict) triples.

6.4.3. Node Splitting

Once a node's best split has been computed, the node is split accordingly: the current node's attributes are filled in, and its left and right children are constructed.

// Iterate over all nodes in this group.
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
  nodesForTree.foreach { node =>
    val nodeIndex = node.id
    val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
    val aggNodeIndex = nodeInfo.nodeIndexInGroup
    // fetch the results just computed for the best split
    val (split: Split, stats: InformationGainStats, predict: Predict) =
      nodeToBestSplits(aggNodeIndex)
    logDebug("best split = " + split)

    // Extract info for this node.  Create children if not leaf.
    // stopping condition
    val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth)
    assert(node.id == nodeIndex)
    node.predict = predict
    node.isLeaf = isLeaf
    node.stats = Some(stats)
    node.impurity = stats.impurity
    logDebug("Node = " + node)

    // if this is not a leaf, construct the left and right children
    if (!isLeaf) {
      node.split = Some(split)
      // the children's depth is the current level + 1
      val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
      // whether the left/right children are leaves
      val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
      val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
      // construct the left and right children
      node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex),
        stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
      node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
        stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))

      if (nodeIdCache.nonEmpty) {
        val nodeIndexUpdater = NodeIndexUpdater(
          split = split,
          nodeIndex = nodeIndex)
        nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
      }

      // enqueue left child and right child if they are not leaves;
      // nodeQueue holds the nodes awaiting further splitting
      if (!leftChildIsLeaf) {
        nodeQueue.enqueue((treeIndex, node.leftNode.get))
      }
      if (!rightChildIsLeaf) {
        nodeQueue.enqueue((treeIndex, node.rightNode.get))
      }

      logDebug("leftChildIndex = " + node.leftNode.get.id +
        ", impurity = " + stats.leftImpurity)
      logDebug("rightChildIndex = " + node.rightNode.get.id +
        ", impurity = " + stats.rightImpurity)
    }
  }
}

The current node's non-leaf children are enqueued into nodeQueue, which holds the nodes that still need to be split; with that, this round of findBestSplits is done.

// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
DecisionTree.findBestSplits(baggedInput,
  metadata, topNodes, nodesForGroup,
  treeToNodeToIndexInfo, splits, bins, nodeQueue,
  timer, nodeIdCache = nodeIdCache)
timer.stop("findBestSplits")

6.5. The Training Loop

As the previous section noted, nodes awaiting a split end up in nodeQueue. Back in RandomForest.run:

while (nodeQueue.nonEmpty) {
  // Collect some nodes to split, and choose features for each node (if subsampling).
  // Each group of nodes may come from one or multiple trees, and at multiple levels.
  val (nodesForGroup, treeToNodeToIndexInfo) =
    RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
  // Sanity check (should never occur):
  assert(nodesForGroup.size > 0,
    s"RandomForest selected empty nodesForGroup.  Error for unknown reason.")

  // Choose node splits, and enqueue new nodes as needed.
  timer.start("findBestSplits")
  DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
    treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
  timer.stop("findBestSplits")
}

As long as non-leaf nodes keep being added to nodeQueue, the loop keeps splitting nodes, until every node has hit a stopping condition.
