Spark MLlib Source Code Analysis: Random Forest (Part 4)
Source: Internet. Editor: 程序博客网. Time: 2024/06/08 04:03
Spark Source Code Analysis: Random Forest (Part 1)
Spark Source Code Analysis: Random Forest (Part 2)
Spark Source Code Analysis: Random Forest (Part 3)
Spark Source Code Analysis: Random Forest (Part 5)
6.4. Node splitting
The logic lives mainly in DecisionTree.findBestSplits, which is the core of RF training.
DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
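6.4.1. Collecting statistics
Statistics are collected in two stages: each partition is aggregated separately, then the per-partition results are merged into global statistics.
6.4.1.1. Fetching each node's feature subset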
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
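This fetches the feature subset of each node. If features are not subsampled the result is None; otherwise it is a Map[Int, Array[Int]], which is simply the NodeIndexInfo entries from treeToNodeToIndexInfo converted into a map. The result is broadcast as nodeToFeaturesBc.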
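The conversion described here can be sketched roughly as follows. This is a minimal sketch, not Spark's code: NodeIndexInfo below is a hypothetical stand-in with only two fields, and the "no subsampling" case is detected by the absence of any feature subset, which is a simplifying assumption.

```scala
object NodeToFeaturesSketch {
  // Hypothetical stand-in for RandomForest.NodeIndexInfo (only the fields used here).
  final case class NodeIndexInfo(nodeIndexInGroup: Int, featureSubset: Option[Array[Int]])

  // Flatten tree -> node -> info into groupIndex -> feature subset.
  // Returns None when no node carries a feature subset (assumed to mean: no subsampling).
  def nodeToFeatures(
      treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = {
    val pairs = for {
      nodeToInfo <- treeToNodeToIndexInfo.values.toSeq
      info <- nodeToInfo.values.toSeq
      features <- info.featureSubset.toSeq
    } yield info.nodeIndexInGroup -> features
    if (pairs.isEmpty) None else Some(pairs.toMap)
  }
}
```

The key of the resulting map is the node's position in the group, not its id in the tree, which is why it can later be used to index the per-partition aggregator array.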
6.4.1.2. Per-partition statistics
This is a chain of function calls, which we go through layer by layer.
val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
  input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
    // Construct a nodeStatsAggregators array to hold node aggregate stats,
    // each node will have a nodeStatsAggregator
    val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
      val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
        Some(nodeToFeatures(nodeIndex))
      }
      new DTStatsAggregator(metadata, featuresForNode)
    }

    // iterator all instances in current partition and update aggregate stats
    points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))

    // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
    // which can be combined with other partition using `reduceByKey`
    nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
  }
} else {
  input.mapPartitions { points =>
    // Construct a nodeStatsAggregators array to hold node aggregate stats,
    // each node will have a nodeStatsAggregator
    val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
      val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
        Some(nodeToFeatures(nodeIndex))
      }
      new DTStatsAggregator(metadata, featuresForNode)
    }

    // iterator all instances in current partition and update aggregate stats
    points.foreach(binSeqOp(nodeStatsAggregators, _))

    // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
    // which can be combined with other partition using `reduceByKey`
    nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
  }
}
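First, each partition builds an array of DTStatsAggregator whose length is the number of nodes. Since this is a plain array, how does a node find its own aggregator? As mentioned earlier, the first member of NodeIndexInfo is the groupIndex (the field nodeIndexInGroup). Its value is the node's position within the group, and it is exactly the index into the aggregator array. In other words, we read groupIndex from NodeIndexInfo and use it as the array index to obtain the node's aggregator. DTStatsAggregator takes metadata and the node's feature subset as arguments. Each point is then accumulated into its DTStatsAggregator through the inner function binSeqOp: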
/**
 * Performs a sequential aggregation over a partition.
 *
 * Each data point contributes to one node. For each feature,
 * the aggregate sufficient statistics are updated for the relevant bins.
 *
 * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
 *             each (node, feature, bin).
 * @param baggedPoint   Data point being aggregated.
 * @return  agg
 */
def binSeqOp(
    agg: Array[DTStatsAggregator],
    baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
  // For each tree that has nodes in this group
  treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
    val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
      bins, metadata.unorderedFeatures)
    nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
  }

  agg
}
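predictNodeIndex is called first to compute nodeIndex. In the first round, or when the node is a leaf, it returns node.id directly. Otherwise, because the argument passed in is the root node of each tree, it starts at the root and walks down to determine which node the point belongs to. Since the nodes above have already been split, this walk effectively partitions the samples. For example, if the node being trained is the root's left child and the point is predicted to fall into the right child, nodeBinSeqOp returns immediately and the point is not counted. This trades a modest amount of time for not having to store the partitioned sample sets, which is rather neat.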
/**
 * Get the node index corresponding to this data point.
 * This function mimics prediction, passing an example from the root node down to a leaf
 * or unsplit node; that node's index is returned.
 *
 * @param node  Node in tree from which to classify the given data point.
 * @param binnedFeatures  Binned feature vector for data point.
 * @param bins possible bins for all features, indexed (numFeatures)(numBins)
 * @param unorderedFeatures  Set of indices of unordered features.
 * @return  Leaf index if the data point reaches a leaf.
 *          Otherwise, last node reachable in tree matching this example.
 *          Note: This is the global node index, i.e., the index used in the tree.
 *                This index is different from the index used during training a particular
 *                group of nodes on one call to [[findBestSplits()]].
 */
private def predictNodeIndex(
    node: Node,
    binnedFeatures: Array[Int],
    bins: Array[Array[Bin]],
    unorderedFeatures: Set[Int]): Int = {
  if (node.isLeaf || node.split.isEmpty) {
    // Node is either leaf, or has not yet been split.
    node.id
  } else {
    // Decide whether the point belongs to the left or the right child of the current node
    val featureIndex = node.split.get.feature
    val splitLeft = node.split.get.featureType match {
      case Continuous => {
        val binIndex = binnedFeatures(featureIndex)
        val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
        // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold]
        // We do not need to check lowSplit since bins are separated by splits.
        featureValueUpperBound <= node.split.get.threshold
      }
      case Categorical => {
        val featureValue = binnedFeatures(featureIndex)
        node.split.get.categories.contains(featureValue)
      }
      case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.")
    }
    if (node.leftNode.isEmpty || node.rightNode.isEmpty) {
      // A child node object is missing, so stop here.
      // Return index from next layer of nodes to train
      if (splitLeft) {
        Node.leftChildIndex(node.id)
      } else {
        Node.rightChildIndex(node.id)
      }
    } else {
      // Both children exist below this node: recurse into the matching child
      if (splitLeft) {
        predictNodeIndex(node.leftNode.get, binnedFeatures, bins, unorderedFeatures)
      } else {
        predictNodeIndex(node.rightNode.get, binnedFeatures, bins, unorderedFeatures)
      }
    }
  }
}
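To make the "routing instead of partitioning" idea concrete, here is a stripped-down sketch. ToyNode, route and contributes are hypothetical names for illustration; the split is reduced to "bin index of one feature is at most some bound", and child ids are assumed to follow the 2*id / 2*id+1 numbering.

```scala
object RoutingSketch {
  // Hypothetical minimal node: an optional split on one binned feature, optional children.
  final case class ToyNode(
      id: Int,
      split: Option[(Int, Int)] = None, // (featureIndex, largest bin index that goes left)
      left: Option[ToyNode] = None,
      right: Option[ToyNode] = None)

  // Walk from the root down to the node that currently owns the point.
  def route(node: ToyNode, binned: Array[Int]): Int = node.split match {
    case None => node.id // leaf or not yet split
    case Some((feature, maxBinLeft)) =>
      val goLeft = binned(feature) <= maxBinLeft
      (if (goLeft) node.left else node.right) match {
        case Some(child) => route(child, binned)
        case None => if (goLeft) 2 * node.id else 2 * node.id + 1
      }
  }

  // A point is counted only if the node it lands in belongs to the group being trained.
  def contributes(root: ToyNode, binned: Array[Int], groupNodeIds: Set[Int]): Boolean =
    groupNodeIds.contains(route(root, binned))
}
```

With a root split on feature 0 at bin 3, a point whose bin index is 5 is routed to node 3, so it is skipped when only node 2 is in the current group.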
Next, nodeBinSeqOp is called:
/**
 * Performs a sequential aggregation over a partition for a particular tree and node.
 *
 * For each feature, the aggregate sufficient statistics are updated for the relevant
 * bins.
 *
 * @param treeIndex Index of the tree that we want to perform aggregation for.
 * @param nodeInfo The node info for the tree node.
 * @param agg Array storing aggregate calculation, with a set of sufficient statistics
 *            for each (node, feature, bin).
 * @param baggedPoint Data point being aggregated.
 */
def nodeBinSeqOp(
    treeIndex: Int,
    nodeInfo: RandomForest.NodeIndexInfo,
    agg: Array[DTStatsAggregator],
    baggedPoint: BaggedPoint[TreePoint]): Unit = {
  if (nodeInfo != null) {
    // groupIndex of the node, see above
    val aggNodeIndex = nodeInfo.nodeIndexInGroup
    // feature subset used by the node
    val featuresForNode = nodeInfo.featureSubset
    // number of times the sample appears in this tree's bootstrap sample: 0/1/k
    val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
    if (metadata.unorderedFeatures.isEmpty) {
      orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
    } else {
      mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
        metadata.unorderedFeatures, instanceWeight, featuresForNode)
    }
  }
}
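The arguments are treeIndex, the node's NodeIndexInfo, the aggregator array for all nodes, and the sample. The function operates on a single node. As shown, the node's aggregator is looked up by using nodeIndexInGroup, the first member of NodeIndexInfo, as the index into the agg array.
If there are no unordered features, orderedBinSeqOp is called: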
/**
 * Helper for binSeqOp, for regression and for classification with only ordered features.
 *
 * For each feature, the sufficient statistics of one bin are updated.
 *
 * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
 *             each (feature, bin).
 * @param treePoint  Data point being aggregated.
 * @param instanceWeight  Weight (importance) of instance in dataset.
 */
private def orderedBinSeqOp(
    agg: DTStatsAggregator, // aggregator of this node
    treePoint: TreePoint,
    instanceWeight: Double,
    featuresForNode: Option[Array[Int]]): Unit = {
  val label = treePoint.label

  // Iterate over features.
  if (featuresForNode.nonEmpty) {
    // Use subsampled features
    var featureIndexIdx = 0
    while (featureIndexIdx < featuresForNode.get.size) {
      // continuous feature: bin index after discretization
      // categorical feature: featureValue.toInt
      val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
      agg.update(featureIndexIdx, binIndex, label, instanceWeight)
      featureIndexIdx += 1
    }
  } else {
    // Use all features
    val numFeatures = agg.metadata.numFeatures
    var featureIndex = 0
    while (featureIndex < numFeatures) {
      val binIndex = treePoint.binnedFeatures(featureIndex)
      agg.update(featureIndex, binIndex, label, instanceWeight)
      featureIndex += 1
    }
  }
}
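The function distinguishes whether all features are used. The only difference is that with feature subsampling, the actual feature index must first be looked up in featuresForNode.
In essence, the function takes the features in use, their binned values, the label and the weight of the sample, and updates the aggregator with them. The update logic was explained in an earlier part.
If unordered categorical features are present, mixedBinSeqOp is used instead. It differs from orderedBinSeqOp only in how unordered categorical features are handled: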
// Unordered feature
val featureValue = treePoint.binnedFeatures(featureIndex)
// Locate the range in allStats that belongs to this feature.
// The feature's start position comes from featureOffsets; its length is the number of bins
// times the stats size, i.e. 2*(2^(M-1)-1)*statsSize.
// Each split divides the samples into two parts. In allStats all left parts are stored
// contiguously, followed by all right parts, rather than left and right interleaved.
// So the left block starts at the offset from featureOffsets, and the right block starts
// (2^(M-1)-1)*statsSize after it.
val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
  agg.getLeftRightFeatureOffsets(featureIndexIdx)
// Update the left or right bin for each split.
val numSplits = agg.metadata.numSplits(featureIndex)
var splitIndex = 0
while (splitIndex < numSplits) {
  // categories of a split holds the feature values on the left side;
  // splitIndex plays the role of the discretized feature index
  if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
    agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
  } else {
    agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
  }
  splitIndex += 1
}
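The "all left blocks first, then all right blocks" layout can be illustrated with a plain array. This is a sketch under assumptions: the function names, the flat Array[Double] and the per-label slot inside each statsSize-wide cell are illustrative choices, not the DTStatsAggregator API.

```scala
object UnorderedLayoutSketch {
  // Number of candidate splits for an unordered feature with `arity` categories: 2^(M-1) - 1.
  def numSplits(arity: Int): Int = (1 << (arity - 1)) - 1

  // (leftOffset, rightOffset) inside the flat stats array, given the feature's base offset.
  def leftRightOffsets(featureOffset: Int, arity: Int, statsSize: Int): (Int, Int) =
    (featureOffset, featureOffset + numSplits(arity) * statsSize)

  // Add `weight` for `label` to the left or right cell of every candidate split.
  def update(
      allStats: Array[Double],
      featureOffset: Int,
      arity: Int,
      statsSize: Int,
      leftCategories: Int => Set[Int], // split index -> categories sent left (hypothetical)
      featureValue: Int,
      label: Int,
      weight: Double): Unit = {
    val (leftOff, rightOff) = leftRightOffsets(featureOffset, arity, statsSize)
    var s = 0
    while (s < numSplits(arity)) {
      val base = if (leftCategories(s).contains(featureValue)) leftOff else rightOff
      allStats(base + s * statsSize + label) += weight
      s += 1
    }
  }
}
```

For a feature with 3 categories and 2 classes, the feature occupies 2 * 3 * 2 = 12 slots: slots 0 to 5 hold the left-side counts of the three splits and slots 6 to 11 the right-side counts.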
6.4.1.3. Global statistics
partitionAggregates.reduceByKey((a, b) => a.merge(b))
This adds up, element by element, the per-partition statistics stored in allStats to obtain the global statistics.
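A local, Spark-free sketch of this merge step is shown below. ToyAggregator and reduceByNode are hypothetical; they only model the flat stats array and mimic what reduceByKey with merge does on (nodeIndex, aggregator) pairs.

```scala
object MergeSketch {
  // Hypothetical stand-in for DTStatsAggregator: only the flat stats array is modelled.
  final class ToyAggregator(val allStats: Array[Double]) {
    // Element-wise in-place sum; both aggregators must share the same layout.
    def merge(other: ToyAggregator): ToyAggregator = {
      require(allStats.length == other.allStats.length, "aggregators must have the same layout")
      var i = 0
      while (i < allStats.length) {
        allStats(i) += other.allStats(i)
        i += 1
      }
      this
    }
  }

  // Local equivalent of reduceByKey((a, b) => a.merge(b)) over (nodeIndex, aggregator) pairs.
  def reduceByNode(parts: Seq[(Int, ToyAggregator)]): Map[Int, ToyAggregator] =
    parts.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2).reduce(_ merge _) }
}
```

Because every partition builds its aggregators from the same metadata and feature subsets, the arrays line up position by position and a plain element-wise sum is enough.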
6.4.2. bestSplits
With all statistics available, we can iterate over all features, compute the impurity gain, and determine the best split.
val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
  .map { case (nodeIndex, aggStats) =>
    val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
      nodeToFeatures(nodeIndex)
    }

    // find best split for each node
    val (split: Split, stats: InformationGainStats, predict: Predict) =
      binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
    (nodeIndex, (split, stats, predict))
  }.collectAsMap()
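For each node this calls binsToBestSplit, which is described in detail below.
6.4.2.1 init
The function first determines which level of the tree the node is on. [Figure: tree structure with node ids]
Node ids are assigned as in the figure. To find a node's level, it is enough to find the position of the highest set bit in the binary representation of its id. For example, 6 is 110 in binary and its highest 1 is in the third position, so the node is on the third layer of the tree when the root counts as the first layer. Node.indexToLevel is zero-based, so the code treats the root as level 0.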
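The bit trick can be written in a few lines. This is a sketch of the idea rather than Spark's source; it assumes the root has id 1 and that children of node i are 2i and 2i+1, which is consistent with the example of node 6 above.

```scala
object NodeIndexSketch {
  // Zero-based level of a node: position of the highest set bit of its id.
  def indexToLevel(nodeId: Int): Int = {
    require(nodeId > 0, "node ids start at 1")
    31 - Integer.numberOfLeadingZeros(nodeId)
  }

  // Child ids under the assumed numbering scheme.
  def leftChildIndex(nodeId: Int): Int = nodeId << 1
  def rightChildIndex(nodeId: Int): Int = (nodeId << 1) + 1
}
```

Node 6 (binary 110) gets level 2 here, that is, the third layer when counting from one.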
Then the prediction and impurity of the current node are obtained:
// calculate predict and impurity if current node is top node
val level = Node.indexToLevel(node.id)
var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
  None
} else {
  Some((node.predict, node.impurity))
}
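6.4.2.2 Continuous features
For a continuous feature, once one of its values is chosen as the split, the samples of the node are divided into those greater than that value and those less than or equal to it, and the class distribution of each part has to be computed. Since we are searching for the best split, every candidate value must be tried. A neat trick is to accumulate the statistics from the left: when a value is used as the split, the running total at that point is the left (less than or equal) side, and the rightmost total minus the left side gives the right side.
For example, the figure shows the distribution of one feature over 6 values; the first row is the cumulative sum from the left and the second row is the raw distribution. With v2 as the split, the left distribution is c0:8, c1:5, and the right distribution is the cumulative value at v6 minus the one at v2: c0: 19-8=11, c1: 14-5=9.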
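The cumulative trick can be replayed on a small array. The per-value counts below are hypothetical: the original figure is not available, so they were chosen only so that the running totals match the numbers quoted above (8/5 at v2, 19/14 at v6).

```scala
object CumulativeSplitSketch {
  // Hypothetical per-value class counts (c0, c1) for v1..v6.
  val raw: Vector[(Int, Int)] = Vector((3, 2), (5, 3), (2, 1), (4, 4), (3, 2), (2, 2))

  // Running totals: entry i is the sum of entries 0..i.
  def cumulative(bins: Vector[(Int, Int)]): Vector[(Int, Int)] =
    bins.scanLeft((0, 0)) { case ((a0, a1), (b0, b1)) => (a0 + b0, a1 + b1) }.tail

  // Left and right class counts when the split is placed after position `splitIdx`.
  def leftRight(cum: Vector[(Int, Int)], splitIdx: Int): ((Int, Int), (Int, Int)) = {
    val left = cum(splitIdx)
    val total = cum.last
    (left, (total._1 - left._1, total._2 - left._2))
  }
}
```

Calling leftRight(cumulative(raw), 1) places the split at v2 and yields (8, 5) on the left and (11, 9) on the right. One pass builds the running totals, after which every candidate split costs a single subtraction.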
if (binAggregates.metadata.isContinuous(featureIndex)) {
  // Cumulative sum (scanLeft) of bin statistics.
  // Afterwards, binAggregates for a bin is the sum of aggregates for
  // that bin + all preceding bins.
  // Accumulate as described above
  val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
  var splitIndex = 0
  while (splitIndex < numSplits) {
    binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
    splitIndex += 1
  }
  // Find best split.
  val (bestFeatureSplitIndex, bestFeatureGainStats) =
    Range(0, numSplits).map { case splitIdx =>
      val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
      val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
      rightChildStats.subtract(leftChildStats)
      // Get the node's impurity; when level == 0 it has to be computed
      // from the current class distribution
      predictWithImpurity = Some(predictWithImpurity.getOrElse(
        calculatePredictImpurity(leftChildStats, rightChildStats)))
      val gainStats = calculateGainForSplit(leftChildStats,
        rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
      (splitIdx, gainStats)
    }.maxBy(_._2.gain)
  (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
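The impurity gain of splitting the node is computed by calculateGainForSplit. It computes the impurity of the left and right parts separately, combines them weighted by each side's share of the samples, subtracts the result from the node's own impurity, and also computes the predictions of the left and right children. The code is fairly simple, so it is not repeated here.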
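As a worked illustration of the weighted combination, here is a sketch that uses Gini impurity on class counts. It is not the body of calculateGainForSplit: the function names are hypothetical, Gini is only one of the possible impurity measures, and any validity checks the real function performs are left out.

```scala
object GainSketch {
  // Gini impurity of a vector of class counts.
  def gini(counts: Seq[Double]): Double = {
    val total = counts.sum
    if (total == 0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
  }

  // gain = impurity(parent) - wLeft * impurity(left) - wRight * impurity(right)
  def gain(left: Seq[Double], right: Seq[Double]): Double = {
    val parent = left.zip(right).map { case (l, r) => l + r }
    val total = parent.sum
    val leftWeight = left.sum / total
    val rightWeight = right.sum / total
    gini(parent) - leftWeight * gini(left) - rightWeight * gini(right)
  }
}
```

A split that separates two balanced classes perfectly, such as gain(Seq(10.0, 0.0), Seq(0.0, 10.0)), removes all impurity, while a split whose two sides have the same class mix as the parent gains nothing.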
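6.4.2.3. Unordered categorical feature
Only the way the left and right class statistics are obtained differs; everything else is the same.
6.4.2.4. Ordered categorical feature
For a continuous feature, the feature values (or bin indices) are ordered, that is, they can be sorted numerically, so choosing a value as the split separates a left part from a right part. For an unordered categorical feature, the side each feature value falls on is fixed once a split is given. For an ordered categorical feature, the values express an order but carry no absolute magnitude. It could be handled approximately like a continuous feature, but Spark adds an extra step here that is probably somewhat better.
Spark first determines a centroid for each feature value and sorts the values by it; the resulting order plays the role of the binIndex of a continuous feature. For example, suppose the centroid is the number of class-1 samples for each feature value, and the feature values 0, 1, 2, 3 have class-1 counts 4, 2, 1, 3. Handled like a continuous feature with 1 as the split point, the impurity gain would be computed on the two parts {0, 1} and {2, 3}. With the centroid method, the values are ordered 2, 1, 3, 0, and splitting at 1 gives the two parts {2, 1} and {3, 0}.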
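The example above can be replayed in a few lines. This sketch only covers the ordering and the enumeration of the K - 1 candidate splits; candidateSplits is a hypothetical helper, and the centroid is taken to be the class-1 count as in the text.

```scala
object CentroidOrderSketch {
  // Sort categories by centroid, then list the K - 1 (left, right) partitions
  // obtained by cutting the sorted sequence after each position.
  def candidateSplits(centroids: Seq[(Int, Double)]): Seq[(Seq[Int], Seq[Int])] = {
    val ordered = centroids.sortBy(_._2).map(_._1)
    (1 until ordered.length).map(i => (ordered.take(i), ordered.drop(i)))
  }
}
```

For the centroids (0 -> 4), (1 -> 2), (2 -> 1), (3 -> 3) the sorted order is 2, 1, 3, 0, and the second candidate is {2, 1} against {3, 0}, as described above.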
// Ordered categorical feature
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val numBins = binAggregates.metadata.numBins(featureIndex)

/* Each bin is one category (feature value).
 * The bins are ordered based on centroidForCategories, and this ordering determines which
 * splits are considered. (With K categories, we consider K - 1 possible splits.)
 *
 * centroidForCategories is a list: (category, centroid)
 */
val centroidForCategories = Range(0, numBins).map { case featureValue =>
  val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
  val centroid = if (categoryStats.count != 0) {
    if (binAggregates.metadata.isMulticlass) {
      // For categorical variables in multiclass classification,
      // the bins are ordered by the impurity of their corresponding labels.
      categoryStats.calculate()
    } else if (binAggregates.metadata.isClassification) {
      // For categorical variables in binary classification,
      // the bins are ordered by the count of class 1.
      categoryStats.stats(1)
    } else {
      // For categorical variables in regression,
      // the bins are ordered by the prediction.
      categoryStats.predict
    }
  } else {
    Double.MaxValue
  }
  (featureValue, centroid)
}

logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))

// bins sorted by centroids
val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
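The code above chooses the centroid differently for each case: impurity for multiclass classification, the count of class 1 for binary classification, and the prediction (in fact the mean) for regression. The feature values are then re-sorted by centroid.
The rest is essentially the same as for continuous features: accumulate in sorted order, then compute the left and right impurity and the impurity gain. A split has to be returned, and the splits built earlier for categorical features are an empty Array, so a Split is constructed here whose fourth argument holds the actual feature values, analogous to the unordered case.
Once the gain has been computed for all features, the split with the largest gain is selected. Finally collectAsMap is called; the key is nodeIndex and the value is the triple (split, InformationGainStats, predict).
6.4.3. Splitting the node
After the best split of a node has been computed, the node is split accordingly. This includes filling in some attributes of the current node and constructing its left and right children.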
// Iterate over all nodes in this group.
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
  nodesForTree.foreach { node =>
    val nodeIndex = node.id
    val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
    val aggNodeIndex = nodeInfo.nodeIndexInGroup
    // Fetch the result of the best-split computation above
    val (split: Split, stats: InformationGainStats, predict: Predict) =
      nodeToBestSplits(aggNodeIndex)
    logDebug("best split = " + split)

    // Extract info for this node. Create children if not leaf.
    // Stopping condition
    val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth)
    assert(node.id == nodeIndex)
    node.predict = predict
    node.isLeaf = isLeaf
    node.stats = Some(stats)
    node.impurity = stats.impurity
    logDebug("Node = " + node)

    // If this is not a leaf, build the left and right children
    if (!isLeaf) {
      node.split = Some(split)
      // The children sit at the current level + 1
      val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
      // Whether the left and right children are leaves
      val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
      val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
      // Build the left and right children
      node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex),
        stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
      node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
        stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))

      if (nodeIdCache.nonEmpty) {
        val nodeIndexUpdater = NodeIndexUpdater(
          split = split,
          nodeIndex = nodeIndex)
        nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
      }

      // Children that are not leaves are added to nodeQueue, the queue of nodes to split
      // enqueue left child and right child if they are not leaves
      if (!leftChildIsLeaf) {
        nodeQueue.enqueue((treeIndex, node.leftNode.get))
      }
      if (!rightChildIsLeaf) {
        nodeQueue.enqueue((treeIndex, node.rightNode.get))
      }

      logDebug("leftChildIndex = " + node.leftNode.get.id +
        ", impurity = " + stats.leftImpurity)
      logDebug("rightChildIndex = " + node.rightNode.get.id +
        ", impurity = " + stats.rightImpurity)
    }
  }
}
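The non-leaf left and right children of the current node are added to nodeQueue, which holds the nodes that still need to be split. This completes the current round of findBestSplits.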
// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
  treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
timer.stop("findBestSplits")
6.5. The training loop
As described in the previous section, nodes that still need splitting end up in nodeQueue. Back in RandomForest.run:
while (nodeQueue.nonEmpty) {
  // Collect some nodes to split, and choose features for each node (if subsampling).
  // Each group of nodes may come from one or multiple trees, and at multiple levels.
  val (nodesForGroup, treeToNodeToIndexInfo) =
    RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
  // Sanity check (should never occur):
  assert(nodesForGroup.size > 0,
    s"RandomForest selected empty nodesForGroup. Error for unknown reason.")

  // Choose node splits, and enqueue new nodes as needed.
  timer.start("findBestSplits")
  DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
    treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
  timer.stop("findBestSplits")
}
As long as non-leaf nodes keep being added to nodeQueue, the loop keeps splitting nodes, until every node has hit a stopping condition.
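The shape of this loop can be tried out with a toy queue. The sketch below is a deliberately simplified assumption: it grows a single tree in which every node splits until maxDepth is reached, and groupSize is a hypothetical stand-in for the memory-based grouping done by selectNodesToSplit.

```scala
import scala.collection.mutable

object TrainingLoopSketch {
  // Returns the ids of the nodes that went through a split round, in processing order.
  def grow(maxDepth: Int, groupSize: Int): Vector[Int] = {
    val queue = mutable.Queue(1) // start with the root, id 1
    val processed = Vector.newBuilder[Int]
    while (queue.nonEmpty) {
      // Take a group of nodes off the queue (stand-in for selectNodesToSplit).
      val n = math.min(groupSize, queue.size)
      val group = Vector.fill(n)(queue.dequeue())
      group.foreach { id =>
        processed += id
        val level = 31 - Integer.numberOfLeadingZeros(id)
        // Children at maxDepth are leaves and are not enqueued.
        if (level + 1 < maxDepth) {
          queue.enqueue(2 * id)
          queue.enqueue(2 * id + 1)
        }
      }
    }
    processed.result()
  }
}
```

With maxDepth = 2 the nodes 1, 2 and 3 are processed and the queue then drains, which mirrors how the real loop ends once no split produces further non-leaf children.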